dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dfa_gym/spaces.py ADDED
@@ -0,0 +1,156 @@
1
+ """ Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """
2
+ from typing import Tuple, Union, Sequence
3
+ from collections import OrderedDict
4
+ import chex
5
+ import jax
6
+ import jax.numpy as jnp
7
+
8
class Space:
    """
    Abstract base for the jittable action/observation spaces in this module.

    Concrete spaces override ``sample`` and ``contains``; the base versions
    only raise ``NotImplementedError``.
    """

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw a random element of the space using ``rng`` (subclass hook)."""
        raise NotImplementedError

    def contains(self, x: jnp.int_) -> bool:
        """Report whether ``x`` lies in the space (subclass hook)."""
        raise NotImplementedError
18
+
19
class Discrete(Space):
    """
    Jittable space of ``n`` integer categories ``0, 1, ..., n - 1``.

    TODO: For now this is a 1d space. Make composable for multi-discrete.
    """

    def __init__(self, num_categories: int, dtype=jnp.int32):
        assert num_categories >= 0
        # Scalar space: samples have an empty shape.
        self.n, self.shape, self.dtype = num_categories, (), dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Return a uniformly drawn category in ``[0, n)``, cast to ``self.dtype``."""
        draw = jax.random.randint(rng, shape=self.shape, minval=0, maxval=self.n)
        return draw.astype(self.dtype)

    def contains(self, x: jnp.int_) -> bool:
        """Elementwise test of ``0 <= x < n``.

        Only the value range is checked; the dtype and shape of ``x`` are not.
        """
        lower_ok = x >= 0
        upper_ok = x < self.n
        return jnp.logical_and(lower_ok, upper_ok)
43
+
44
+
45
class MultiDiscrete(Space):
    """
    Minimal jittable class for multi-discrete gymnax spaces.

    Dimension ``i`` takes integer values in ``[0, num_categories[i])``.
    For example, ``num_categories=[2, 2, 2]`` describes 3 dimensions with
    2 choices each.
    """

    def __init__(self, num_categories: Sequence[int], dtype=jnp.int_):
        """
        Args:
            num_categories: number of categories for each dimension.
            dtype: integer dtype of sampled actions. It defaults to ``jnp.int_``
                (the value previously hard-coded) and mirrors the ``dtype``
                argument of ``Discrete``.
        """
        self.num_categories = jnp.array(num_categories)
        self.shape = (len(num_categories),)
        self.dtype = dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Sample each dimension uniformly from its own ``[0, num_categories[i])``."""
        # ``maxval`` is an array, so each position gets its own upper bound.
        return jax.random.randint(
            rng,
            shape=self.shape,
            minval=0,
            maxval=self.num_categories,
            dtype=self.dtype,
        )

    def contains(self, x: jnp.int_) -> bool:
        """Return True iff every element of ``x`` satisfies ``0 <= x[i] < num_categories[i]``."""
        range_cond = jnp.logical_and(x >= 0, x < self.num_categories)
        return jnp.all(range_cond)
70
+
71
+
72
class Box(Space):
    """
    Jittable space of arrays of shape ``shape`` bounded by ``low`` and ``high``.

    TODO: Add unboundedness - sampling from other distributions, etc.
    """

    def __init__(
        self,
        low: float,
        high: float,
        shape: Tuple[int],
        dtype: jnp.dtype = jnp.float32,
    ):
        self.low, self.high = low, high
        self.shape, self.dtype = shape, dtype

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Sample an array of ``self.shape`` uniformly between ``low`` and ``high``.

        Uses ``jax.random.uniform`` semantics and casts the result to ``self.dtype``.
        """
        values = jax.random.uniform(
            rng, shape=self.shape, minval=self.low, maxval=self.high
        )
        return values.astype(self.dtype)

    def contains(self, x: jnp.int_) -> bool:
        """Return True iff every element of ``x`` lies in the closed interval ``[low, high]``.

        The shape and dtype of ``x`` are not checked.
        """
        above_low = jnp.all(x >= self.low)
        below_high = jnp.all(x <= self.high)
        return jnp.logical_and(above_low, below_high)
103
+
104
+
105
class Dict(Space):
    """Minimal jittable class for a dictionary of simpler jittable spaces."""

    def __init__(self, spaces: dict):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> dict:
        """Sample every sub-space with its own split of ``rng``.

        Returns an ``OrderedDict`` whose keys follow the iteration order of
        ``self.spaces``.
        """
        key_split = jax.random.split(rng, self.num_spaces)
        return OrderedDict(
            (k, space.sample(key_split[i]))
            for i, (k, space) in enumerate(self.spaces.items())
        )

    def contains(self, x: object) -> bool:
        """Check whether every entry of ``x`` lies in the matching sub-space.

        ``x`` may be a mapping, such as the dict returned by ``sample``; entries
        are then looked up by key. Otherwise ``x`` may be any object that
        exposes the sub-space names as attributes.

        Returns a boolean (a JAX array once any sub-space is checked). It is
        ``True`` for an empty space.
        """
        from collections.abc import Mapping

        # ``getattr`` alone fails on plain dicts, including our own samples.
        is_mapping = isinstance(x, Mapping)
        in_space = True
        for k, space in self.spaces.items():
            value = x[k] if is_mapping else getattr(x, k)
            in_space = jnp.logical_and(in_space, space.contains(value))
        return in_space
130
+
131
+
132
class Tuple(Space):
    """Minimal jittable class for a tuple (product) of jittable spaces.

    Element ``i`` of a member tuple belongs to ``spaces[i]``.

    NOTE: this class name shadows ``typing.Tuple`` imported at module top.
    """

    def __init__(self, spaces: Union[tuple, list]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> Tuple[chex.Array]:
        """Sample every sub-space with its own split of ``rng`` and return a tuple."""
        key_split = jax.random.split(rng, self.num_spaces)
        return tuple(
            space.sample(key_split[i]) for i, space in enumerate(self.spaces)
        )

    def contains(self, x: Union[tuple, list]) -> bool:
        """Check whether ``x[i]`` lies in ``spaces[i]`` for every position ``i``.

        A length mismatch between ``x`` and ``self.spaces`` yields ``False``
        rather than being silently truncated by ``zip``. Returns ``True`` for an
        empty space with an empty ``x``.
        """
        if len(x) != self.num_spaces:
            return False
        in_space = True
        # Each sub-space checks only its own element, not the whole tuple.
        for space, element in zip(self.spaces, x):
            in_space = jnp.logical_and(in_space, space.contains(element))
        return in_space