dfa-gym 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dfa_gym-0.1.0 → dfa_gym-0.2.0}/.gitignore +3 -0
- dfa_gym-0.2.0/.python-version +1 -0
- dfa_gym-0.2.0/PKG-INFO +93 -0
- dfa_gym-0.2.0/README.md +83 -0
- dfa_gym-0.2.0/dfa_gym/__init__.py +6 -0
- dfa_gym-0.2.0/dfa_gym/dfa_bisim_env.py +121 -0
- dfa_gym-0.2.0/dfa_gym/dfa_wrapper.py +190 -0
- dfa_gym-0.2.0/dfa_gym/env.py +168 -0
- dfa_gym-0.2.0/dfa_gym/maps/2buttons_2agents.pdf +0 -0
- dfa_gym-0.2.0/dfa_gym/maps/2rooms_2agents.pdf +0 -0
- dfa_gym-0.2.0/dfa_gym/maps/4buttons_4agents.pdf +0 -0
- dfa_gym-0.2.0/dfa_gym/maps/4rooms_4agents.pdf +0 -0
- dfa_gym-0.2.0/dfa_gym/robot.png +0 -0
- dfa_gym-0.2.0/dfa_gym/spaces.py +156 -0
- dfa_gym-0.2.0/dfa_gym/token_env.py +571 -0
- dfa_gym-0.2.0/dfa_gym/utils.py +266 -0
- dfa_gym-0.2.0/pyproject.toml +16 -0
- dfa_gym-0.2.0/requirements.txt +46 -0
- dfa_gym-0.2.0/test.py +137 -0
- dfa_gym-0.2.0/uv.lock +7 -0
- dfa_gym-0.1.0/.python-version +0 -1
- dfa_gym-0.1.0/PKG-INFO +0 -11
- dfa_gym-0.1.0/README.md +0 -1
- dfa_gym-0.1.0/dfa_gym/__init__.py +0 -16
- dfa_gym-0.1.0/dfa_gym/dfa_env.py +0 -45
- dfa_gym-0.1.0/dfa_gym/dfa_wrapper.py +0 -57
- dfa_gym-0.1.0/pyproject.toml +0 -14
- dfa_gym-0.1.0/test.py +0 -26
- dfa_gym-0.1.0/uv.lock +0 -181
- {dfa_gym-0.1.0 → dfa_gym-0.2.0}/LICENSE +0 -0
dfa_gym-0.2.0/.python-version
ADDED
@@ -0,0 +1 @@
+3.10
dfa_gym-0.2.0/PKG-INFO
ADDED
@@ -0,0 +1,93 @@
+Metadata-Version: 2.4
+Name: dfa-gym
+Version: 0.2.0
+Summary: Python library for playing DFA bisimulation games and wrapping other RL environments with DFA goals.
+Author-email: Beyazit Yalcinkaya <beyazit@berkeley.edu>
+License-File: LICENSE
+Requires-Python: >=3.10
+Requires-Dist: dfax>=0.1.1
+Description-Content-Type: text/markdown
+
+# dfa-gym
+
+This repo implements (Multi-Agent) Reinforcement Learning environments in JAX for solving objectives given as Deterministic Finite Automata (DFAs). There are three environments:
+
+1. `TokenEnv` is a fully observable grid environment with tokens in cells. The grid can be created randomly or from a specific layout. It can be instantiated in both single- and multi-agent settings.
+2. `DFAWrapper` is an environment wrapper assigning tasks represented as Deterministic Finite Automata (DFAs) to the agents in the wrapped environment. DFAs are represented as [`DFAx`](https://github.com/rad-dfa/dfax) objects.
+3. `DFABisimEnv` is an environment for solving DFA bisimulation games to learn RAD Embeddings, provably correct latent DFA representations, as described in [this paper](https://arxiv.org/pdf/2503.05042).
+
+
+## Installation
+
+This package will soon be made pip-installable. In the meantime, pull the repo and install it locally.
+
+```
+git clone https://github.com/rad-dfa/dfa-gym.git
+pip install -e dfa-gym
+```
+
+## TokenEnv
+
+Create a grid world with token and agent positions assigned randomly.
+
+```python
+from dfa_gym import TokenEnv
+
+env = TokenEnv(
+    n_agents=1,               # Single agent
+    n_tokens=10,              # 10 different token types
+    n_token_repeat=2,         # Each token repeated twice
+    grid_shape=(7, 7),        # Shape of the grid
+    fixed_map_seed=None,      # If not None, the same map is sampled using the given seed
+    max_steps_in_episode=100, # Episode length is 100
+)
+```
+
+Create a grid world from a given layout.
+
+```python
+layout = """
+[ 0 ][   ][   ][   ][ # ][ # ][ # ][ # ][ # ]
+[   ][   ][ a ][   ][#,a][ 0 ][   ][ 2 ][ # ]
+[ A ][   ][ a ][   ][#,a][   ][ 8 ][   ][ # ]
+[   ][   ][ a ][   ][#,a][ 6 ][   ][ 4 ][ # ]
+[ 1 ][   ][   ][ 3 ][ # ][ # ][ # ][ # ][ # ]
+[   ][   ][ b ][   ][#,b][ 1 ][   ][ 3 ][ # ]
+[ B ][   ][ b ][   ][#,b][   ][ 9 ][   ][ # ]
+[   ][   ][ b ][   ][#,b][ 7 ][   ][ 5 ][ # ]
+[ 2 ][   ][   ][   ][ # ][ # ][ # ][ # ][ # ]
+"""
+env = TokenEnv(
+    layout=layout, # Set layout, where each [] indicates a cell, uppercase letters are
+                   # agents, # are walls, and lowercase letters are buttons when alone
+                   # and doors when paired with a wall. For example, [#,a] is a door
+                   # that is open if an agent is on a [ a ] cell and closed otherwise.
+)
+```
+
+
+## DFAWrapper
+
+Wrap a `TokenEnv` instance using `DFAWrapper`.
+
+```python
+from dfa_gym import DFAWrapper
+from dfax.samplers import ReachSampler
+
+env = DFAWrapper(
+    env=TokenEnv(layout=layout),
+    sampler=ReachSampler()
+)
+```
+
+## DFABisimEnv
+
+Create a DFA bisimulation game.
+
+```python
+from dfa_gym import DFABisimEnv
+from dfax.samplers import RADSampler
+
+env = DFABisimEnv(sampler=RADSampler())
+```
+
dfa_gym-0.2.0/README.md
ADDED
@@ -0,0 +1,83 @@
+# dfa-gym
+
+This repo implements (Multi-Agent) Reinforcement Learning environments in JAX for solving objectives given as Deterministic Finite Automata (DFAs). There are three environments:
+
+1. `TokenEnv` is a fully observable grid environment with tokens in cells. The grid can be created randomly or from a specific layout. It can be instantiated in both single- and multi-agent settings.
+2. `DFAWrapper` is an environment wrapper assigning tasks represented as Deterministic Finite Automata (DFAs) to the agents in the wrapped environment. DFAs are represented as [`DFAx`](https://github.com/rad-dfa/dfax) objects.
+3. `DFABisimEnv` is an environment for solving DFA bisimulation games to learn RAD Embeddings, provably correct latent DFA representations, as described in [this paper](https://arxiv.org/pdf/2503.05042).
+
+
+## Installation
+
+This package will soon be made pip-installable. In the meantime, pull the repo and install it locally.
+
+```
+git clone https://github.com/rad-dfa/dfa-gym.git
+pip install -e dfa-gym
+```
+
+## TokenEnv
+
+Create a grid world with token and agent positions assigned randomly.
+
+```python
+from dfa_gym import TokenEnv
+
+env = TokenEnv(
+    n_agents=1,               # Single agent
+    n_tokens=10,              # 10 different token types
+    n_token_repeat=2,         # Each token repeated twice
+    grid_shape=(7, 7),        # Shape of the grid
+    fixed_map_seed=None,      # If not None, the same map is sampled using the given seed
+    max_steps_in_episode=100, # Episode length is 100
+)
+```
+
+Create a grid world from a given layout.
+
+```python
+layout = """
+[ 0 ][   ][   ][   ][ # ][ # ][ # ][ # ][ # ]
+[   ][   ][ a ][   ][#,a][ 0 ][   ][ 2 ][ # ]
+[ A ][   ][ a ][   ][#,a][   ][ 8 ][   ][ # ]
+[   ][   ][ a ][   ][#,a][ 6 ][   ][ 4 ][ # ]
+[ 1 ][   ][   ][ 3 ][ # ][ # ][ # ][ # ][ # ]
+[   ][   ][ b ][   ][#,b][ 1 ][   ][ 3 ][ # ]
+[ B ][   ][ b ][   ][#,b][   ][ 9 ][   ][ # ]
+[   ][   ][ b ][   ][#,b][ 7 ][   ][ 5 ][ # ]
+[ 2 ][   ][   ][   ][ # ][ # ][ # ][ # ][ # ]
+"""
+env = TokenEnv(
+    layout=layout, # Set layout, where each [] indicates a cell, uppercase letters are
+                   # agents, # are walls, and lowercase letters are buttons when alone
+                   # and doors when paired with a wall. For example, [#,a] is a door
+                   # that is open if an agent is on a [ a ] cell and closed otherwise.
+)
+```
+
+
+## DFAWrapper
+
+Wrap a `TokenEnv` instance using `DFAWrapper`.
+
+```python
+from dfa_gym import DFAWrapper
+from dfax.samplers import ReachSampler
+
+env = DFAWrapper(
+    env=TokenEnv(layout=layout),
+    sampler=ReachSampler()
+)
+```
+
+## DFABisimEnv
+
+Create a DFA bisimulation game.
+
+```python
+from dfa_gym import DFABisimEnv
+from dfax.samplers import RADSampler
+
+env = DFABisimEnv(sampler=RADSampler())
+```
+
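The README stops at constructing the environments. Below is a minimal rollout sketch, assuming the JaxMARL-style `reset`/`step` API defined in `env.py` later in this diff, that `TokenEnv`'s constructor defaults are compatible with `ReachSampler`'s token count (the wrapper asserts this), and that the spaces in `spaces.py` expose a JaxMARL-style `sample(key)` method; the ten-step loop and random actions are purely illustrative, not part of the package.

```python
import jax

from dfa_gym import TokenEnv, DFAWrapper
from dfax.samplers import ReachSampler

# DFAWrapper asserts that the sampler and the wrapped env agree on n_tokens.
env = DFAWrapper(env=TokenEnv(n_agents=1), sampler=ReachSampler())

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)

for _ in range(10):
    key, act_key, step_key = jax.random.split(key, 3)
    # One random action per agent, drawn from that agent's action space
    # (assumes Space.sample(key) exists, as in JaxMARL).
    actions = {
        agent: env.action_space(agent).sample(act_key)
        for agent in env.agents
    }
    # step() auto-resets the episode once dones["__all__"] is True.
    obs, state, rewards, dones, infos = env.step(step_key, state, actions)
```

Because `reset` and `step` are jitted methods, a loop like this can also be expressed with `jax.lax.scan` for fully compiled rollouts.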
dfa_gym-0.2.0/dfa_gym/dfa_bisim_env.py
ADDED
@@ -0,0 +1,121 @@
+import jax
+import dfax
+import chex
+import jax.numpy as jnp
+from flax import struct
+from functools import partial
+from typing import Tuple, Dict
+from dfa_gym import spaces
+from dfa_gym.env import MultiAgentEnv, State
+from dfax.samplers import DFASampler, RADSampler
+
+
+@struct.dataclass
+class DFABisimState(State):
+    dfa_l: dfax.DFAx
+    dfa_r: dfax.DFAx
+    time: int
+
+class DFABisimEnv(MultiAgentEnv):
+
+    def __init__(
+        self,
+        sampler: DFASampler = RADSampler(),
+        max_steps_in_episode: int = 100
+    ) -> None:
+        super().__init__(num_agents=1)
+        self.n_agents = self.num_agents
+        self.sampler = sampler
+        self.max_steps_in_episode = max_steps_in_episode
+
+        self.agents = [f"agent_{i}" for i in range(self.n_agents)]
+
+        self.action_spaces = {
+            agent: spaces.Discrete(self.sampler.n_tokens)
+            for agent in self.agents
+        }
+        max_dfa_size = self.sampler.max_size
+        n_tokens = self.sampler.n_tokens
+        self.observation_spaces = {
+            agent: spaces.Dict({
+                "graph_l": spaces.Dict({
+                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
+                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
+                    "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
+                    "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
+                    "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
+                }),
+                "graph_r": spaces.Dict({
+                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
+                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
+                    "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
+                    "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
+                    "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
+                })
+            })
+            for agent in self.agents
+        }
+
+    @partial(jax.jit, static_argnums=(0,))
+    def reset(
+        self,
+        key: chex.PRNGKey
+    ) -> Tuple[Dict[str, chex.Array], DFABisimState]:
+
+        def cond_fn(carry):
+            _, dfa_l, dfa_r = carry
+            return dfa_l == dfa_r
+
+        def body_fn(carry):
+            key, _, _ = carry
+            key, kl, kr = jax.random.split(key, 3)
+            dfa_l = self.sampler.sample(kl)
+            dfa_r = self.sampler.sample(kr)
+            return (key, dfa_l, dfa_r)
+
+        init_carry = body_fn((key, None, None))
+        _, dfa_l, dfa_r = jax.lax.while_loop(cond_fn, body_fn, init_carry)
+
+        state = DFABisimState(dfa_l=dfa_l, dfa_r=dfa_r, time=0)
+        obs = self.get_obs(state=state)
+
+        return {self.agents[0]: obs}, state
+
+    @partial(jax.jit, static_argnums=(0,))
+    def step_env(
+        self,
+        key: chex.PRNGKey,
+        state: DFABisimState,
+        action: int
+    ) -> Tuple[Dict[str, chex.Array], DFABisimState, Dict[str, float], Dict[str, bool], Dict]:
+
+        dfa_l = state.dfa_l.advance(action[self.agents[0]]).minimize()
+        dfa_r = state.dfa_r.advance(action[self.agents[0]]).minimize()
+
+        reward_l = dfa_l.reward(binary=False)
+        reward_r = dfa_r.reward(binary=False)
+        reward = reward_l - reward_r
+
+        new_state = DFABisimState(
+            dfa_l=dfa_l,
+            dfa_r=dfa_r,
+            time=state.time+1
+        )
+
+        done = jnp.logical_or(jnp.logical_or(dfa_l.n_states <= 1, dfa_r.n_states <= 1), new_state.time >= self.max_steps_in_episode)
+
+        obs = self.get_obs(state=new_state)
+        info = {}
+
+        return {self.agents[0]: obs}, new_state, {self.agents[0]: reward}, {self.agents[0]: done, "__all__": done}, info
+
+    @partial(jax.jit, static_argnums=(0,))
+    def get_obs(
+        self,
+        state: DFABisimState
+    ) -> Dict[str, chex.Array]:
+        return {
+            "graph_l": state.dfa_l.to_graph(),
+            "graph_r": state.dfa_r.to_graph()
+        }
+
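In the file above, `reset` rejection-samples two distinct DFAs, each `step_env` advances both DFAs on the same token and rewards the difference of their non-binary DFA rewards, and the episode ends once either DFA collapses to a single state or the step budget runs out. A minimal interaction sketch follows, assuming `RADSampler`'s defaults; the token value `0` and the single call to `step` are illustrative only.

```python
import jax
import jax.numpy as jnp

from dfa_gym import DFABisimEnv
from dfax.samplers import RADSampler

env = DFABisimEnv(sampler=RADSampler(), max_steps_in_episode=100)

key = jax.random.PRNGKey(0)
key, reset_key, step_key = jax.random.split(key, 3)

# The observation for the single agent holds one graph per DFA
# ("graph_l" and "graph_r").
obs, state = env.reset(reset_key)

# The agent plays one token, which is fed to both DFAs.
actions = {"agent_0": jnp.array(0)}
obs, state, rewards, dones, infos = env.step(step_key, state, actions)
print(rewards["agent_0"])  # reward(dfa_l) - reward(dfa_r) for this step
```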
dfa_gym-0.2.0/dfa_gym/dfa_wrapper.py
ADDED
@@ -0,0 +1,190 @@
+import jax
+import dfax
+import chex
+import jax.numpy as jnp
+from flax import struct
+from dfa_gym import spaces
+from functools import partial
+from typing import Tuple, Dict, Callable
+from dfax.utils import list2batch, batch2graph
+from dfa_gym.env import MultiAgentEnv, State
+from dfax.samplers import DFASampler, RADSampler
+
+
+@struct.dataclass
+class DFAWrapperState(State):
+    dfas: Dict[str, dfax.DFAx]
+    init_dfas: Dict[str, dfax.DFAx]
+    env_obs: chex.Array
+    env_state: State
+
+class DFAWrapper(MultiAgentEnv):
+
+    def __init__(
+        self,
+        env: MultiAgentEnv,
+        gamma: float | None = None,
+        sampler: DFASampler = RADSampler(),
+        binary_reward: bool = True,
+        progress: bool = True,
+    ) -> None:
+        super().__init__(num_agents=env.num_agents)
+        self.env = env
+        self.gamma = gamma
+        self.sampler = sampler
+        self.binary_reward = binary_reward
+        self.progress = progress
+
+        assert self.sampler.n_tokens == self.env.n_tokens
+
+        self.agents = [f"agent_{i}" for i in range(self.num_agents)]
+
+        self.action_spaces = {
+            agent: self.env.action_space(agent)
+            for agent in self.agents
+        }
+        max_dfa_size = self.sampler.max_size
+        n_tokens = self.sampler.n_tokens
+        self.observation_spaces = {
+            agent: spaces.Dict({
+                "_id": spaces.Discrete(self.num_agents),
+                "obs": self.env.observation_space(agent),
+                "dfa": spaces.Dict({
+                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents, 4), dtype=jnp.float32),
+                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents*max_dfa_size*self.num_agents, n_tokens + 8), dtype=jnp.float32),
+                    "edge_index": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(2, max_dfa_size*self.num_agents*max_dfa_size*self.num_agents), dtype=jnp.int32),
+                    "current_state": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(self.num_agents,), dtype=jnp.int32),
+                    "n_states": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(max_dfa_size*self.num_agents,), dtype=jnp.int32)
+                }),
+            })
+            for agent in self.agents
+        }
+
+    @partial(jax.jit, static_argnums=(0,))
+    def reset(
+        self,
+        key: chex.PRNGKey
+    ) -> Tuple[Dict[str, chex.Array], DFAWrapperState]:
+        keys = jax.random.split(key, 4 + self.num_agents)
+
+        env_obs, env_state = self.env.reset(keys[1])
+
+        n_trivial = jax.random.choice(keys[2], self.num_agents)
+        mask = jax.random.permutation(keys[3], jnp.arange(self.num_agents) < n_trivial)
+
+        def sample_dfa(dfa_key, sample_trivial):
+            return jax.tree_util.tree_map(
+                lambda t, s: jnp.where(sample_trivial, t, s),
+                self.sampler.trivial(True),
+                self.sampler.sample(dfa_key)
+            )
+
+        dfas_tree = jax.vmap(sample_dfa)(keys[4:], mask)
+
+        dfas = {
+            agent: jax.tree_util.tree_map(lambda x: x[i], dfas_tree)
+            for i, agent in enumerate(self.agents)
+        }
+
+        state = DFAWrapperState(
+            dfas=dfas,
+            init_dfas={agent: dfas[agent] for agent in self.agents},
+            env_obs=env_obs,
+            env_state=env_state
+        )
+        obs = self.get_obs(state=state)
+
+        return obs, state
+
+    @partial(jax.jit, static_argnums=(0,))
+    def step_env(
+        self,
+        key: chex.PRNGKey,
+        state: DFAWrapperState,
+        action: int,
+    ) -> Tuple[Dict[str, chex.Array], DFAWrapperState, Dict[str, float], Dict[str, bool], Dict]:
+
+        env_obs, env_state, env_rewards, env_dones, env_info = self.env.step_env(key, state.env_state, action)
+
+        symbols = self.env.label_f(env_state)
+
+        dfas = {
+            agent: state.dfas[agent].advance(symbols[agent]).minimize()
+            for agent in self.agents
+        }
+
+        dones = {
+            agent: jnp.logical_or(env_dones[agent], dfas[agent].n_states <= 1)
+            for agent in self.agents
+        }
+        _dones = jnp.array([dones[agent] for agent in self.agents])
+        dones.update({"__all__": jnp.all(_dones)})
+
+        dfa_rewards_min = jnp.min(jnp.array([dfas[agent].reward(binary=self.binary_reward) for agent in self.agents]))
+        rewards = {
+            agent: jax.lax.cond(
+                dones["__all__"],
+                lambda _: env_rewards[agent] + dfa_rewards_min,
+                lambda _: env_rewards[agent],
+                operand=None
+            )
+            for agent in self.agents
+        }
+
+        if self.gamma is not None:
+            rewards = {
+                agent: rewards[agent] + self.gamma * dfas[agent].reward(binary=self.binary_reward) - state.dfas[agent].reward(binary=self.binary_reward)
+                for agent in self.agents
+            }
+
+        infos = {}
+
+        state = DFAWrapperState(
+            dfas=dfas,
+            init_dfas=state.init_dfas,
+            env_obs=env_obs,
+            env_state=env_state
+        )
+
+        obs = self.get_obs(state=state)
+
+        return obs, state, rewards, dones, infos
+
+    @partial(jax.jit, static_argnums=(0,))
+    def get_obs(
+        self,
+        state: DFAWrapperState
+    ) -> Dict[str, chex.Array]:
+        if self.progress:
+            dfas = batch2graph(
+                list2batch(
+                    [state.dfas[agent].to_graph() for agent in self.agents]
+                )
+            )
+        else:
+            dfas = batch2graph(
+                list2batch(
+                    [state.init_dfas[agent].to_graph() for agent in self.agents]
+                )
+            )
+        return {
+            agent: {
+                "_id": i,
+                "obs": state.env_obs[agent],
+                "dfa": dfas
+            }
+            for i, agent in enumerate(self.agents)
+        }
+
+    def render(self, state: DFAWrapperState):
+        out = ""
+        for agent in self.agents:
+            out += "****\n"
+            out += f"{agent}'s DFA:\n"
+            if self.progress:
+                out += f"{state.dfas[agent]}\n"
+            else:
+                out += f"{state.init_dfas[agent]}\n"
+        self.env.render(state.env_state)
+        print(out)
+
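Two details of the wrapper above are easy to miss: the (minimum) DFA reward is only added to the environment reward once the joint episode terminates, via the `jax.lax.cond` on `dones["__all__"]`, and setting `gamma` additionally applies potential-based shaping with each agent's DFA reward as the potential. A standalone sketch of that shaping arithmetic, with made-up numbers purely for illustration:

```python
# Shaping term used in DFAWrapper.step_env when gamma is not None:
#   shaped = base + gamma * phi(dfa_after_step) - phi(dfa_before_step)
# where phi is dfa.reward(binary=...) of the agent's task DFA.
gamma = 0.99
phi_before = 0.0   # task DFA not yet accepting before the step (illustrative)
phi_after = 1.0    # the advanced, minimized DFA is accepting (illustrative)
base_reward = 0.0  # reward coming from the wrapped environment (illustrative)

shaped_reward = base_reward + gamma * phi_after - phi_before
print(shaped_reward)  # 0.99
```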
dfa_gym-0.2.0/dfa_gym/env.py
ADDED
@@ -0,0 +1,168 @@
+"""
+Abstract base class for multi-agent gym environments with JAX
+Based on the JaxMARL APIs
+"""
+
+import jax
+import jax.numpy as jnp
+from typing import Dict
+import chex
+from functools import partial
+from flax import struct
+from typing import Tuple, Optional
+
+from dfa_gym.spaces import Space
+
+@struct.dataclass
+class State:
+    pass
+
+
+class MultiAgentEnv(object):
+    """Jittable abstract base class for all JaxMARL Environments."""
+
+    def __init__(
+        self,
+        num_agents: int,
+    ) -> None:
+        """
+        Args:
+            num_agents (int): maximum number of agents within the environment, used to set array dimensions
+        """
+        self.num_agents = num_agents
+        self.observation_spaces = dict()
+        self.action_spaces = dict()
+
+    @partial(jax.jit, static_argnums=(0,))
+    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
+        """Performs resetting of the environment.
+
+        Args:
+            key (chex.PRNGKey): random key
+
+        Returns:
+            Observations (Dict[str, chex.Array]): observations for each agent, keyed by agent name
+            State (State): environment state
+        """
+        raise NotImplementedError
+
+    @partial(jax.jit, static_argnums=(0,))
+    def step(
+        self,
+        key: chex.PRNGKey,
+        state: State,
+        actions: Dict[str, chex.Array],
+        reset_state: Optional[State] = None,
+    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+        """Performs step transitions in the environment. Resets the environment if done.
+        To control the reset state, pass `reset_state`. Otherwise, the environment will reset using `self.reset`.
+
+        Args:
+            key (chex.PRNGKey): random key
+            state (State): environment state
+            actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+            reset_state (Optional[State], optional): Optional environment state to reset to on episode completion. Defaults to None.
+
+        Returns:
+            Observations (Dict[str, chex.Array]): next observations
+            State (State): next environment state
+            Rewards (Dict[str, float]): rewards, keyed by agent name
+            Dones (Dict[str, bool]): dones, keyed by agent name
+            Info (Dict): info dictionary
+        """
+
+        key, key_reset = jax.random.split(key)
+        obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
+
+        if reset_state is None:
+            obs_re, states_re = self.reset(key_reset)
+        else:
+            states_re = reset_state
+            obs_re = self.get_obs(states_re)
+
+        # Auto-reset environment based on termination
+        states = jax.tree.map(
+            lambda x, y: jax.lax.select(dones["__all__"], x, y), states_re, states_st
+        )
+        obs = jax.tree.map(
+            lambda x, y: jax.lax.select(dones["__all__"], x, y), obs_re, obs_st
+        )
+        return obs, states, rewards, dones, infos
+
+    def step_env(
+        self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
+    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+        """Environment-specific step transition.
+
+        Args:
+            key (chex.PRNGKey): random key
+            state (State): environment state
+            actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+
+        Returns:
+            Observations (Dict[str, chex.Array]): next observations
+            State (State): next environment state
+            Rewards (Dict[str, float]): rewards, keyed by agent name
+            Dones (Dict[str, bool]): dones, keyed by agent name
+            Info (Dict): info dictionary
+        """
+
+        raise NotImplementedError
+
+    def get_obs(self, state: State) -> Dict[str, chex.Array]:
+        """Applies observation function to state.
+
+        Args:
+            state (State): environment state
+
+        Returns:
+            Observations (Dict[str, chex.Array]): observations keyed by agent names"""
+        raise NotImplementedError
+
+    def observation_space(self, agent: str) -> Space:
+        """Observation space for a given agent.
+
+        Args:
+            agent (str): agent name
+
+        Returns:
+            space (Space): observation space
+        """
+        return self.observation_spaces[agent]
+
+    def action_space(self, agent: str) -> Space:
+        """Action space for a given agent.
+
+        Args:
+            agent (str): agent name
+
+        Returns:
+            space (Space): action space
+        """
+        return self.action_spaces[agent]
+
+    @partial(jax.jit, static_argnums=(0,))
+    def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
+        """Returns the available actions for each agent.
+
+        Args:
+            state (State): environment state
+
+        Returns:
+            available actions (Dict[str, chex.Array]): available actions keyed by agent name
+        """
+        raise NotImplementedError
+
+    @property
+    def name(self) -> str:
+        """Environment name."""
+        return type(self).__name__
+
+    @property
+    def agent_classes(self) -> dict:
+        """Returns a dictionary with agent classes
+
+        Format:
+            agent_names: [agent_base_name_1, agent_base_name_2, ...]
+        """
+        raise NotImplementedError
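Since `reset`, `step_env`, and the auto-resetting `step` above are jitted pure functions of a PRNG key and a `State` pytree, whole rollouts can be compiled with `jax.lax.scan`. A minimal sketch, assuming a `DFABisimEnv` instance, random token actions in place of a learned policy, an arbitrary horizon of 128 steps, and that the padded `max_size` DFA graphs keep the state pytree's shapes static across steps:

```python
import jax

from dfa_gym import DFABisimEnv
from dfax.samplers import RADSampler

env = DFABisimEnv(sampler=RADSampler())

def rollout(key, horizon=128):
    key, reset_key = jax.random.split(key)
    _, state = env.reset(reset_key)

    def one_step(carry, _):
        key, state = carry
        key, act_key, step_key = jax.random.split(key, 3)
        # Random tokens stand in for a policy here.
        actions = {
            agent: jax.random.randint(act_key, (), 0, env.sampler.n_tokens)
            for agent in env.agents
        }
        # step() auto-resets whenever dones["__all__"] is True.
        _, state, rewards, dones, _ = env.step(step_key, state, actions)
        return (key, state), rewards[env.agents[0]]

    _, reward_trace = jax.lax.scan(one_step, (key, state), None, length=horizon)
    return reward_trace

rewards = jax.jit(rollout)(jax.random.PRNGKey(0))
```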
dfa_gym-0.2.0/dfa_gym/maps/2buttons_2agents.pdf
ADDED
Binary file
dfa_gym-0.2.0/dfa_gym/maps/2rooms_2agents.pdf
ADDED
Binary file
dfa_gym-0.2.0/dfa_gym/maps/4buttons_4agents.pdf
ADDED
Binary file
dfa_gym-0.2.0/dfa_gym/maps/4rooms_4agents.pdf
ADDED
Binary file
dfa_gym-0.2.0/dfa_gym/robot.png
ADDED
Binary file