dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dfa_gym/__init__.py +5 -15
- dfa_gym/dfa_bisim_env.py +121 -0
- dfa_gym/dfa_wrapper.py +185 -52
- dfa_gym/env.py +168 -0
- dfa_gym/maps/2buttons_2agents.pdf +0 -0
- dfa_gym/maps/2rooms_2agents.pdf +0 -0
- dfa_gym/maps/4buttons_4agents.pdf +0 -0
- dfa_gym/maps/4rooms_4agents.pdf +0 -0
- dfa_gym/robot.png +0 -0
- dfa_gym/spaces.py +156 -0
- dfa_gym/token_env.py +571 -0
- dfa_gym/utils.py +266 -0
- dfa_gym-0.2.0.dist-info/METADATA +93 -0
- dfa_gym-0.2.0.dist-info/RECORD +16 -0
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/WHEEL +1 -1
- dfa_gym/dfa_env.py +0 -45
- dfa_gym-0.1.0.dist-info/METADATA +0 -11
- dfa_gym-0.1.0.dist-info/RECORD +0 -7
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/licenses/LICENSE +0 -0
dfa_gym/spaces.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
""" Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """
|
|
2
|
+
from typing import Tuple, Union, Sequence
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
import chex
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
class Space:
    """
    Abstract base for the jittable action/observation spaces in this module.

    Concrete spaces override both ``sample`` and ``contains``; the versions
    defined here only raise ``NotImplementedError``.
    """

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw one element of the space using the PRNG key ``rng``."""
        raise NotImplementedError

    def contains(self, x: jnp.int_) -> bool:
        """Report whether ``x`` is a member of the space."""
        raise NotImplementedError
|
|
18
|
+
|
|
19
|
+
class Discrete(Space):
    """
    Jittable one-dimensional discrete space over the integers ``0 .. n - 1``.

    TODO: For now this is a 1d space. Make composable for multi-discrete.
    """

    def __init__(self, num_categories: int, dtype=jnp.int32):
        # Negative category counts are rejected (note: `assert` is skipped under -O).
        assert num_categories >= 0
        self.n, self.shape, self.dtype = num_categories, (), dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw a category index uniformly from ``[0, n)`` and cast it to ``dtype``."""
        draw = jax.random.randint(rng, shape=self.shape, minval=0, maxval=self.n)
        return draw.astype(self.dtype)

    def contains(self, x: jnp.int_) -> bool:
        """
        Return ``0 <= x < n``.

        No dtype or shape check is performed; for an array ``x`` the result
        is computed elementwise (not reduced).
        """
        lower_ok = x >= 0
        upper_ok = x < self.n
        return jnp.logical_and(lower_ok, upper_ok)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class MultiDiscrete(Space):
    """
    Minimal jittable class for multi-discrete gymnax spaces.

    An element is a length-``len(num_categories)`` integer vector whose i-th
    entry lies in ``[0, num_categories[i])``.
    """

    def __init__(self, num_categories: Sequence[int], dtype=jnp.int_):
        """
        Args:
            num_categories: number of categorical choices for each dimension,
                e.g. ``[2, 2, 2]`` = 2 actions x 3 dims.
            dtype: requested integer dtype of samples. Defaults to
                ``jnp.int_`` (the previous hard-coded value).
        """
        self.num_categories = jnp.array(num_categories)
        self.shape = (len(num_categories),)
        self.dtype = dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Sample each dimension uniformly from its own set of categorical choices."""
        # `jnp.int_` is a 64-bit type. With x64 disabled, handing it directly
        # to randint triggers a "dtype int64 ... not available" UserWarning and
        # is truncated to int32 anyway; canonicalizing first yields the same
        # dtype (int32 without x64, int64 with x64) without the warning.
        sample_dtype = jax.dtypes.canonicalize_dtype(self.dtype)
        return jax.random.randint(
            rng,
            shape=self.shape,
            minval=0,
            maxval=self.num_categories,
            dtype=sample_dtype,
        )

    def contains(self, x: jnp.int_) -> bool:
        """Return True iff every entry satisfies ``0 <= x[i] < num_categories[i]``."""
        range_cond = jnp.logical_and(x >= 0, x < self.num_categories)
        return jnp.all(range_cond)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class Box(Space):
    """
    Jittable bounded space of arrays with a fixed ``shape``.

    TODO: Add unboundedness - sampling from other distributions, etc.
    """

    def __init__(
        self,
        low: float,
        high: float,
        shape: Tuple[int],
        dtype: jnp.dtype = jnp.float32,
    ):
        self.low, self.high = low, high
        self.shape, self.dtype = shape, dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw uniformly from ``[low, high)`` (before the cast) and cast to ``dtype``."""
        uniform_draw = jax.random.uniform(
            rng, shape=self.shape, minval=self.low, maxval=self.high
        )
        return uniform_draw.astype(self.dtype)

    def contains(self, x: jnp.int_) -> bool:
        """
        Return True iff every entry of ``x`` lies in the closed range ``[low, high]``.

        Neither dtype nor shape of ``x`` is checked.
        """
        above_low = jnp.all(x >= self.low)
        below_high = jnp.all(x <= self.high)
        return jnp.logical_and(above_low, below_high)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Dict(Space):
    """Minimal jittable class for dictionary of simpler jittable spaces."""

    def __init__(self, spaces: dict):
        # Mapping from name to subspace; its iteration order fixes the order
        # of keys produced by `sample` and the PRNG split assignment.
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> dict:
        """Sample every subspace with its own split of ``rng``; returns an OrderedDict."""
        key_split = jax.random.split(rng, self.num_spaces)
        return OrderedDict(
            (k, self.spaces[k].sample(key_split[i]))
            for i, k in enumerate(self.spaces)
        )

    def contains(self, x) -> bool:
        """
        Check whether each entry of ``x`` lies within its matching subspace.

        ``x`` may be a mapping (e.g. the OrderedDict returned by ``sample``),
        in which case entries are looked up by key, or any object exposing the
        subspace names as attributes. Extra entries in ``x`` are ignored.

        Raises:
            KeyError / AttributeError: if an entry for some subspace is missing.
        """
        from collections.abc import Mapping

        is_mapping = isinstance(x, Mapping)
        result = jnp.array(True)
        for k, space in self.spaces.items():
            # Previously only `getattr` was used, which fails on real dicts.
            value = x[k] if is_mapping else getattr(x, k)
            # `jnp.all` reduces subspaces whose `contains` is elementwise.
            result = jnp.logical_and(result, jnp.all(space.contains(value)))
        return result
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class Tuple(Space):
    """
    Minimal jittable class for tuple (product) of jittable spaces.

    NOTE: this class name shadows the `typing.Tuple` imported at the top of the
    module for any code defined after this class.
    """

    def __init__(self, spaces: Union[tuple, list]):
        # Ordered subspaces; element i of a sample belongs to spaces[i].
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> Tuple[chex.Array, ...]:
        """Sample every subspace with its own split of ``rng``; returns a tuple."""
        key_split = jax.random.split(rng, self.num_spaces)
        return tuple(
            space.sample(key_split[i]) for i, space in enumerate(self.spaces)
        )

    def contains(self, x) -> bool:
        """
        Check whether ``x[i]`` lies within ``spaces[i]`` for every i.

        Returns False when ``x`` does not have exactly one entry per subspace.
        """
        if len(x) != self.num_spaces:
            return jnp.array(False)
        result = jnp.array(True)
        # Previously the whole `x` was passed to every subspace; pair each
        # element with its own subspace instead.
        for space, element in zip(self.spaces, x):
            result = jnp.logical_and(result, jnp.all(space.contains(element)))
        return result
|