pufferlib-core 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pufferlib_core-2.0.0/PKG-INFO +41 -0
- pufferlib_core-2.0.0/README.md +10 -0
- pufferlib_core-2.0.0/pufferlib/__init__.py +10 -0
- pufferlib_core-2.0.0/pufferlib/emulation.py +518 -0
- pufferlib_core-2.0.0/pufferlib/spaces.py +25 -0
- pufferlib_core-2.0.0/pufferlib/vector.py +925 -0
- pufferlib_core-2.0.0/pufferlib_core.egg-info/PKG-INFO +41 -0
- pufferlib_core-2.0.0/pufferlib_core.egg-info/SOURCES.txt +11 -0
- pufferlib_core-2.0.0/pufferlib_core.egg-info/dependency_links.txt +1 -0
- pufferlib_core-2.0.0/pufferlib_core.egg-info/requires.txt +2 -0
- pufferlib_core-2.0.0/pufferlib_core.egg-info/top_level.txt +1 -0
- pufferlib_core-2.0.0/pyproject.toml +53 -0
- pufferlib_core-2.0.0/setup.cfg +4 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pufferlib-core
|
|
3
|
+
Version: 2.0.0
|
|
4
|
+
Summary: Minimal PufferLib core functionality with vectorized environments
|
|
5
|
+
Author-email: Joseph Suarez <jsuarez@openai.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/PufferAI/PufferLib
|
|
8
|
+
Project-URL: Documentation, https://puffer.ai
|
|
9
|
+
Project-URL: Repository, https://github.com/PufferAI/PufferLib
|
|
10
|
+
Project-URL: Issues, https://github.com/PufferAI/PufferLib/issues
|
|
11
|
+
Keywords: reinforcement-learning,machine-learning,multi-agent,vectorized-environments
|
|
12
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
13
|
+
Classifier: Environment :: Console
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: Intended Audience :: Education
|
|
16
|
+
Classifier: Intended Audience :: Science/Research
|
|
17
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
18
|
+
Classifier: Operating System :: OS Independent
|
|
19
|
+
Classifier: Programming Language :: Python :: 3
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
24
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
25
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
26
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
27
|
+
Requires-Python: >=3.8
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
Requires-Dist: numpy
|
|
30
|
+
Requires-Dist: gymnasium
|
|
31
|
+
|
|
32
|
+
# PufferLib Core
|
|
33
|
+
|
|
34
|
+
Minimal PufferLib core functionality with vectorized environments.
|
|
35
|
+
|
|
36
|
+
This package contains only the essential components:
|
|
37
|
+
- `spaces`: Observation/action space handling
|
|
38
|
+
- `emulation`: Environment compatibility layer for Gym/Gymnasium/PettingZoo
|
|
39
|
+
- `vector`: Vectorized environment implementations
|
|
40
|
+
|
|
41
|
+
For the full PufferLib with training capabilities and environments, see the main `pufferlib` package.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# PufferLib Core
|
|
2
|
+
|
|
3
|
+
Minimal PufferLib core functionality with vectorized environments.
|
|
4
|
+
|
|
5
|
+
This package contains only the essential components:
|
|
6
|
+
- `spaces`: Observation/action space handling
|
|
7
|
+
- `emulation`: Environment compatibility layer for Gym/Gymnasium/PettingZoo
|
|
8
|
+
- `vector`: Vectorized environment implementations
|
|
9
|
+
|
|
10
|
+
For the full PufferLib with training capabilities and environments, see the main `pufferlib` package.
|
|
@@ -0,0 +1,518 @@
|
|
|
1
|
+
from pdb import set_trace as T
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import gymnasium
|
|
7
|
+
import inspect
|
|
8
|
+
|
|
9
|
+
import pufferlib
|
|
10
|
+
import pufferlib.spaces
|
|
11
|
+
from pufferlib.spaces import Discrete, Tuple, Dict
|
|
12
|
+
|
|
13
|
+
def emulate(struct, sample):
    '''Recursively copy a nested sample into the matching fields of struct.

    Dict samples are matched to fields by key, tuple samples to fields
    named f0, f1, ... by position. Any other value is written into struct
    in place.
    '''
    if isinstance(sample, dict):
        for key, value in sample.items():
            emulate(struct[key], value)
        return

    if isinstance(sample, tuple):
        for position, value in enumerate(sample):
            emulate(struct[f'f{position}'], value)
        return

    # Leaf: assign through the empty-tuple index so the write lands in the
    # existing buffer instead of rebinding the local name
    struct[()] = sample
|
|
22
|
+
|
|
23
|
+
def make_buffer(arr_dtype, struct_dtype, struct, n=None):
    '''Reinterpret struct as a plain array of arr_dtype.

    With n=None the view is flattened to one dimension (None instead of 1
    makes it work for 1 agent PZ envs); otherwise it is reshaped to n rows.
    struct_dtype is accepted but not used by this function.
    '''
    flat_view = struct.view(arr_dtype)
    return flat_view.ravel() if n is None else flat_view.reshape(n, -1)
|
|
40
|
+
|
|
41
|
+
def _nativize(struct, space):
    '''Rebuild a native sample for space from the structured value struct.

    Discrete spaces yield a Python scalar, Tuple spaces a tuple built from
    fields f0, f1, ..., Dict spaces a dict keyed like the space. Any other
    space returns struct unchanged.
    '''
    if isinstance(space, Discrete):
        return struct.item()

    if isinstance(space, Tuple):
        members = []
        for position, subspace in enumerate(space):
            members.append(_nativize(struct[f'f{position}'], subspace))
        return tuple(members)

    if isinstance(space, Dict):
        native = {}
        for key, subspace in space.items():
            native[key] = _nativize(struct[key], subspace)
        return native

    return struct
|
|
52
|
+
|
|
53
|
+
def nativize(arr, space, struct_dtype):
    '''Convert a flat array back into a native sample of space.

    arr is viewed as struct_dtype and only the first record of that view
    is unpacked.
    '''
    records = np.asarray(arr).view(struct_dtype)
    first_record = records[0]
    return _nativize(first_record, space)
|
|
56
|
+
|
|
57
|
+
# TODO: Uncomment?
|
|
58
|
+
'''
|
|
59
|
+
try:
|
|
60
|
+
from pufferlib.extensions import emulate, nativize
|
|
61
|
+
except ImportError:
|
|
62
|
+
warnings.warn('PufferLib Cython extensions not installed. Using slow Python versions')
|
|
63
|
+
'''
|
|
64
|
+
|
|
65
|
+
def get_dtype_bounds(dtype):
    '''Return the (min, max) pair used to bound a Box of the given dtype.

    Args:
        dtype: A NumPy dtype or scalar type.

    Returns:
        (0, 1) for bool, the full representable range for signed and
        unsigned integers, and the float32 range for every floating type.

    Raises:
        ValueError: If dtype is not bool, integer, or floating.
    '''
    if dtype == bool:
        return 0, 1

    # np.integer is the parent of both signed and unsigned integer types,
    # so a separate np.unsignedinteger branch could never be reached
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        return limits.min, limits.max

    if np.issubdtype(dtype, np.floating):
        # Gym fails on float64, so every float type is bounded by float32
        limits = np.finfo(np.float32)
        return limits.min, limits.max

    raise ValueError(f"Unsupported dtype: {dtype}")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def dtype_from_space(space):
    '''Build the aligned NumPy dtype that stores one sample of space.

    Tuple spaces become structured dtypes with fields f0, f1, ...; Dict
    spaces use their keys as field names. Discrete and MultiDiscrete are
    stored as int32; any other space uses its own dtype and shape.
    '''
    spaces = pufferlib.spaces

    if isinstance(space, spaces.Tuple):
        spec = [(f'f{position}', dtype_from_space(subspace))
                for position, subspace in enumerate(space)]
    elif isinstance(space, spaces.Dict):
        spec = [(key, dtype_from_space(subspace))
                for key, subspace in space.items()]
    elif isinstance(space, spaces.Discrete):
        spec = (np.int32, ())
    elif isinstance(space, spaces.MultiDiscrete):
        spec = (np.int32, (len(space.nvec),))
    else:
        spec = (space.dtype, space.shape)

    return np.dtype(spec, align=True)
|
|
96
|
+
|
|
97
|
+
def flatten_space(space):
    '''Return the leaf spaces of space as a flat list, depth first.'''
    if isinstance(space, pufferlib.spaces.Tuple):
        children = space
    elif isinstance(space, pufferlib.spaces.Dict):
        children = space.values()
    else:
        # Anything that is not a Tuple or Dict is treated as a leaf
        return [space]

    leaves = []
    for child in children:
        leaves += flatten_space(child)
    return leaves
|
|
110
|
+
|
|
111
|
+
def emulate_observation_space(space):
    '''Return (flat observation space, structured dtype) for space.

    A Box is returned unchanged. Otherwise the result is a 1-D Box whose
    dtype is the leaves' common dtype (uint8 when the leaves disagree) and
    whose length is the structured itemsize divided by that dtype's
    itemsize.
    '''
    struct_dtype = dtype_from_space(space)
    if isinstance(space, pufferlib.spaces.Box):
        return space, struct_dtype

    leaf_dtypes = [leaf.dtype for leaf in flatten_space(space)]
    first = leaf_dtypes[0]
    if all(candidate == first for candidate in leaf_dtypes):
        flat_dtype = first
    else:
        # Mixed leaf dtypes: expose the raw bytes
        flat_dtype = np.dtype(np.uint8)

    low, high = get_dtype_bounds(flat_dtype)
    # NOTE(review): the length is derived from the leaves' reported dtype,
    # while dtype_from_space stores Discrete/MultiDiscrete as int32 --
    # confirm the division is exact for spaces made only of those leaves
    length = struct_dtype.itemsize // flat_dtype.itemsize
    flat_space = gymnasium.spaces.Box(
        low=low, high=high, shape=(length,), dtype=flat_dtype)
    return flat_space, struct_dtype
|
|
128
|
+
|
|
129
|
+
def emulate_action_space(space):
    '''Return (flat action space, dtype) for space.

    Box, Discrete and MultiDiscrete spaces are returned unchanged, paired
    with the Box's own dtype or int32 respectively. Any other space is
    flattened into a MultiDiscrete built from the n of each leaf.
    '''
    spaces = pufferlib.spaces

    if isinstance(space, spaces.Box):
        return space, space.dtype

    if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
        return space, np.int32

    struct_dtype = dtype_from_space(space)
    sizes = [leaf.n for leaf in flatten_space(space)]
    return gymnasium.spaces.MultiDiscrete(sizes), struct_dtype
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class GymnasiumPufferEnv(gymnasium.Env):
    '''Wrap a single-agent env behind emulated observation/action spaces.

    Observations are written into the buffers installed by
    pufferlib.set_buffers; structured observations are packed into that
    buffer through a structured view, and emulated actions are unpacked
    before being forwarded to the wrapped env.
    '''

    def __init__(self, env=None, env_creator=None, env_args=None, env_kwargs=None, buf=None, seed=0):
        '''Create the wrapper from an env instance or an env creator.

        Exactly one of env / env_creator must be given (enforced by
        make_object). env_args and env_kwargs default to None rather than
        shared mutable containers; make_object substitutes empty ones.
        seed is accepted but not used by this constructor.
        '''
        self.env = make_object(env, env_creator, env_args, env_kwargs)

        self.initialized = False
        self.done = True

        self.is_observation_checked = False
        self.is_action_checked = False

        self.observation_space, self.obs_dtype = emulate_observation_space(
            self.env.observation_space)
        self.action_space, self.atn_dtype = emulate_action_space(
            self.env.action_space)
        self.single_observation_space = self.observation_space
        self.single_action_space = self.action_space
        self.num_agents = 1

        # The emulate_* helpers return the original space object when no
        # emulation is needed, so identity tells us whether to convert
        self.is_obs_emulated = self.single_observation_space is not self.env.observation_space
        self.is_atn_emulated = self.single_action_space is not self.env.action_space
        self.emulated = dict(
            observation_dtype=self.observation_space.dtype,
            emulated_observation_dtype=self.obs_dtype,
        )

        self.render_modes = 'human rgb_array'.split()

        pufferlib.set_buffers(self, buf)
        if isinstance(self.env.observation_space, pufferlib.spaces.Box):
            self.obs_struct = self.observations
        else:
            # Structured view over the same memory as self.observations
            self.obs_struct = self.observations.view(self.obs_dtype)

    @property
    def render_mode(self):
        '''Render mode reported by the wrapped env.'''
        return self.env.render_mode

    def seed(self, seed):
        '''Forward the seed to the wrapped env's seed() method.'''
        self.env.seed(seed)

    def reset(self, seed=None):
        '''Reset the wrapped env and return (observations, info).'''
        self.initialized = True
        self.done = False

        ob, info = _seed_and_reset(self.env, seed)
        # The observation is validated against the env's space only once
        if not self.is_observation_checked:
            self.is_observation_checked = check_space(
                ob, self.env.observation_space)

        if self.is_obs_emulated:
            emulate(self.obs_struct, ob)
        else:
            self.observations[:] = ob

        self.rewards[0] = 0
        self.terminals[0] = False
        self.truncations[0] = False
        self.masks[0] = True

        return self.observations, info

    def step(self, action):
        '''Execute an action and return (observation, reward, done, truncated, info)

        Raises:
            pufferlib.APIUsageError: If called before reset() or after the
                episode has finished.
        '''
        if not self.initialized:
            raise pufferlib.APIUsageError('step() called before reset()')
        if self.done:
            raise pufferlib.APIUsageError('step() called after environment is done')

        # Unpack actions from multidiscrete into the original action space
        if self.is_atn_emulated:
            action = nativize(action, self.env.action_space, self.atn_dtype)
        elif isinstance(action, np.ndarray):
            action = action.ravel()
            # TODO: profile or speed up
            if isinstance(self.action_space, pufferlib.spaces.Discrete):
                action = action[0]

        # The action is validated against the env's space only once
        if not self.is_action_checked:
            self.is_action_checked = check_space(
                action, self.env.action_space)

        ob, reward, done, truncated, info = self.env.step(action)

        if self.is_obs_emulated:
            emulate(self.obs_struct, ob)
        else:
            self.observations[:] = ob

        self.rewards[0] = reward
        self.terminals[0] = done
        self.truncations[0] = truncated
        self.masks[0] = True

        self.done = done or truncated
        return self.observations, reward, done, truncated, info

    def render(self):
        '''Delegate rendering to the wrapped env.'''
        return self.env.render()

    def close(self):
        '''Close the wrapped env.'''
        return self.env.close()
|
|
243
|
+
|
|
244
|
+
class PettingZooPufferEnv:
    '''Wrap a multi-agent env behind emulated, fixed-size buffers.

    Spaces are derived from the first entry of possible_agents and reused
    for every agent. Per-agent data is stored in the buffers installed by
    pufferlib.set_buffers, one row per entry of possible_agents.
    '''

    def __init__(self, env=None, env_creator=None, env_args=None, env_kwargs=None, buf=None, seed=0):
        '''Create the wrapper from an env instance or an env creator.

        Exactly one of env / env_creator must be given (enforced by
        make_object). env_args and env_kwargs default to None rather than
        shared mutable containers; make_object substitutes empty ones.
        seed is accepted but not used by this constructor.
        '''
        self.env = make_object(env, env_creator, env_args, env_kwargs)
        self.initialized = False
        self.all_done = True

        self.is_observation_checked = False
        self.is_action_checked = False

        # Compute the observation and action spaces
        single_agent = self.possible_agents[0]
        self.env_single_observation_space = self.env.observation_space(single_agent)
        self.env_single_action_space = self.env.action_space(single_agent)
        self.single_observation_space, self.obs_dtype = (
            emulate_observation_space(self.env_single_observation_space))
        self.single_action_space, self.atn_dtype = (
            emulate_action_space(self.env_single_action_space))
        # The emulate_* helpers return the original space object when no
        # emulation is needed, so identity tells us whether to convert
        self.is_obs_emulated = self.single_observation_space is not self.env_single_observation_space
        self.is_atn_emulated = self.single_action_space is not self.env_single_action_space
        self.emulated = dict(
            observation_dtype = self.single_observation_space.dtype,
            emulated_observation_dtype = self.obs_dtype,
        )

        self.num_agents = len(self.possible_agents)

        pufferlib.set_buffers(self, buf)
        if isinstance(self.env_single_observation_space, pufferlib.spaces.Box):
            self.obs_struct = self.observations
        else:
            # Structured view over the same memory as self.observations
            self.obs_struct = self.observations.view(self.obs_dtype)

    @property
    def render_mode(self):
        '''Render mode reported by the wrapped env.'''
        return self.env.render_mode

    @property
    def agents(self):
        '''Agents list reported by the wrapped env.'''
        return self.env.agents

    @property
    def possible_agents(self):
        '''possible_agents list reported by the wrapped env.'''
        return self.env.possible_agents

    @property
    def done(self):
        '''True when no agents remain or the last step ended for everyone.'''
        return len(self.agents) == 0 or self.all_done

    def observation_space(self, agent):
        '''Returns the observation space for a single agent'''
        if agent not in self.possible_agents:
            raise pufferlib.InvalidAgentError(agent, self.possible_agents)

        return self.single_observation_space

    def action_space(self, agent):
        '''Returns the action space for a single agent'''
        if agent not in self.possible_agents:
            raise pufferlib.InvalidAgentError(agent, self.possible_agents)

        return self.single_action_space

    def reset(self, seed=None):
        '''Reset the wrapped env and return (dict of observations, info).'''
        if not self.initialized:
            # Built once: each value is a row of the shared observation
            # buffer, so later writes to the buffer show up in the dict
            self.dict_obs = {agent: self.observations[i] for i, agent in enumerate(self.possible_agents)}

        self.initialized = True
        self.all_done = False
        self.mask = {k: False for k in self.possible_agents}

        obs, info = self.env.reset(seed=seed)

        if not self.is_observation_checked:
            for k, ob in obs.items():
                self.is_observation_checked = check_space(
                    ob, self.env.observation_space(k))

        # Call user featurizer and flatten the observations
        self.observations[:] = 0
        for i, agent in enumerate(self.possible_agents):
            if agent not in obs:
                continue

            ob = obs[agent]
            self.mask[agent] = True
            if self.is_obs_emulated:
                emulate(self.obs_struct[i], ob)
            else:
                self.observations[i] = ob

        self.rewards[:] = 0
        self.terminals[:] = False
        self.truncations[:] = False
        self.masks[:] = True
        return self.dict_obs, info

    def step(self, actions):
        '''Step the environment and return (observations, rewards, dones, truncateds, infos)

        actions may be a dict keyed by agent or an array with one entry
        per possible agent, in possible_agents order.

        Raises:
            pufferlib.APIUsageError: If called before reset(), after the
                env is done, or with an array of the wrong length.
            pufferlib.InvalidAgentError: If an action is supplied for an
                agent that is not in possible_agents.
        '''
        if not self.initialized:
            raise pufferlib.APIUsageError('step() called before reset()')
        if self.done:
            raise pufferlib.APIUsageError('step() called after environment is done')

        if isinstance(actions, np.ndarray):
            if not self.is_action_checked and len(actions) != self.num_agents:
                raise pufferlib.APIUsageError(
                    f'Actions specified as len {len(actions)} but environment has {self.num_agents} agents')

            actions = {agent: actions[i] for i, agent in enumerate(self.possible_agents)}

        # Postprocess actions and validate action spaces
        if not self.is_action_checked:
            for agent in actions:
                if agent not in self.possible_agents:
                    raise pufferlib.InvalidAgentError(agent, self.possible_agents)

            self.is_action_checked = check_space(
                next(iter(actions.values())),
                self.single_action_space
            )

        # Unpack actions from multidiscrete into the original action space
        unpacked_actions = {}
        for agent, atn in actions.items():
            if agent not in self.possible_agents:
                # Report the list the membership test used, matching the
                # other InvalidAgentError call sites in this class
                raise pufferlib.InvalidAgentError(agent, self.possible_agents)

            # Actions for agents that are no longer active are dropped
            if agent not in self.agents:
                continue

            if self.is_atn_emulated:
                atn = nativize(atn, self.env_single_action_space, self.atn_dtype)

            unpacked_actions[agent] = atn

        obs, rewards, dones, truncateds, infos = self.env.step(unpacked_actions)
        # TODO: Can add this assert once NMMO Horizon is ported to puffer
        # assert all(dones.values()) == (len(self.env.agents) == 0)
        self.mask = {k: False for k in self.possible_agents}
        self.rewards[:] = 0
        self.terminals[:] = True
        self.truncations[:] = False
        for i, agent in enumerate(self.possible_agents):
            # TODO: negative padding buf
            if agent not in obs:
                # Agent produced no observation: zero its row, mask it out
                self.observations[i] = 0
                self.rewards[i] = 0
                self.terminals[i] = True
                self.truncations[i] = False
                self.masks[i] = False
                continue

            ob = obs[agent]
            self.mask[agent] = True
            if self.is_obs_emulated:
                emulate(self.obs_struct[i], ob)
            else:
                self.observations[i] = ob

            self.rewards[i] = rewards[agent]
            self.terminals[i] = dones[agent]
            self.truncations[i] = truncateds[agent]
            self.masks[i] = True

        self.all_done = all(dones.values()) or all(truncateds.values())
        rewards = pad_agent_data(rewards, self.possible_agents, 0)
        # NOTE(review): the pad value was changed from False to True to
        # match the API test -- confirm this is the intended semantics
        dones = pad_agent_data(dones, self.possible_agents, True)
        truncateds = pad_agent_data(truncateds, self.possible_agents, False)
        return self.dict_obs, rewards, dones, truncateds, infos

    def render(self):
        '''Delegate rendering to the wrapped env.'''
        return self.env.render()

    def close(self):
        '''Close the wrapped env.'''
        return self.env.close()
|
|
419
|
+
|
|
420
|
+
def pad_agent_data(data, agents, pad_value):
    '''Return a dict with one entry per agent, filling gaps with pad_value.'''
    padded = {}
    for agent in agents:
        padded[agent] = data[agent] if agent in data else pad_value
    return padded
|
|
423
|
+
|
|
424
|
+
def make_object(object_instance=None, object_creator=None, creator_args=[], creator_kwargs={}):
    '''Return object_instance, or the result of calling object_creator.

    Exactly one of object_instance / object_creator must be supplied.
    creator_args and creator_kwargs may be None, which is treated as
    empty.

    Raises:
        ValueError: If both or neither of the two sources are given.
        TypeError: If object_instance is callable or a class, or if
            object_creator is not callable.
    '''
    has_instance = object_instance is not None
    has_creator = object_creator is not None
    if has_instance == has_creator:
        raise ValueError('Exactly one of object_instance or object_creator must be provided')

    if has_instance:
        if callable(object_instance) or inspect.isclass(object_instance):
            raise TypeError('object_instance must be an instance, not a function or class')
        return object_instance

    if not callable(object_creator):
        raise TypeError('object_creator must be a callable')

    positional = [] if creator_args is None else creator_args
    keyword = {} if creator_kwargs is None else creator_kwargs
    return object_creator(*positional, **keyword)
|
|
444
|
+
|
|
445
|
+
def check_space(data, space):
    '''Validate that data is a member of space.

    Returns:
        True when space.contains(data) is truthy.

    Raises:
        pufferlib.APIUsageError: If space.contains raises an Exception
            (chained as the cause) or reports that data is not contained.
    '''
    try:
        contains = space.contains(data)
    except Exception as err:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate; keep the original error as the cause
        raise pufferlib.APIUsageError(
            f'Error checking space {space} with sample :\n{data}') from err

    if not contains:
        raise pufferlib.APIUsageError(
            f'Data:\n{data}\n not in space:\n{space}')

    return True
|
|
457
|
+
|
|
458
|
+
def _seed_and_reset(env, seed):
    '''Reset env with the given seed, falling back for older seeding APIs.

    Tries env.reset(seed=seed) first, then env.seed(seed) followed by
    env.reset(), and finally an unseeded env.reset() with a warning.

    Returns:
        Whatever env.reset() returns when seed is None, otherwise the
        (obs, info) pair unpacked from the reset call.
    '''
    if seed is None:
        # Gym bug: does not reset env correctly
        # when seed is passed as explicit None
        return env.reset()

    # The fallbacks are deliberately best-effort, but only for ordinary
    # errors: KeyboardInterrupt / SystemExit must not be swallowed
    try:
        obs, info = env.reset(seed=seed)
    except Exception:
        try:
            env.seed(seed)
            obs, info = env.reset()
        except Exception:
            obs, info = env.reset()
            warnings.warn('WARNING: Environment does not support seeding.', DeprecationWarning)

    return obs, info
|
|
475
|
+
|
|
476
|
+
class GymnaxPufferEnv(pufferlib.PufferEnv):
    '''Run num_envs copies of a Gymnax env through vmapped, jitted calls.'''

    def __init__(self, env, env_params, num_envs=1, buf=None):
        '''Derive the spaces from env_params and build the batched callables.'''
        from gymnax.spaces import gymnax_space_to_gym_space

        gymnax_obs_space = env.observation_space(env_params)
        self.single_observation_space = gymnax_space_to_gym_space(gymnax_obs_space)

        gymnax_act_space = env.action_space(env_params)
        self.single_action_space = gymnax_space_to_gym_space(gymnax_act_space)

        # Each vectorized env copy is exposed as one agent
        self.num_agents = num_envs

        super().__init__(buf)
        self.env_params = env_params
        self.env = env

        import jax
        # RNG keys (and state/actions for step) are mapped over axis 0;
        # the env params are not mapped
        self.reset_fn = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
        self.step_fn = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))
        self.rng = jax.random.PRNGKey(0)

    def reset(self, rng, params=None):
        '''Reset every env copy and return (observations, []).

        rng is accepted but not used; keys are split from the internal
        self.rng. When params is None the env_params given at construction
        are used, matching what step() passes to the env.
        '''
        import jax
        if params is None:
            params = self.env_params

        self.rng, _rng = jax.random.split(self.rng)
        self.rngs = jax.random.split(_rng, self.num_agents)
        obs, self.state = self.reset_fn(self.rngs, params)
        from torch.utils import dlpack as torch_dlpack
        self.observations = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(obs))
        return self.observations, []

    def step(self, action):
        '''Step every env copy with the keys created by the last reset().

        Returns (observations, rewards, terminals, terminals, infos): the
        terminals array is returned in both the terminal and truncation
        positions, and infos is a one-element list holding the mean of
        each info entry.
        '''
        import jax
        #self.rng, _rng = jax.random.split(self.rng)
        #rngs = jax.random.split(_rng, self.num_agents)
        obs, self.state, reward, done, info = self.step_fn(self.rngs, self.state, action, self.env_params)

        # Convert JAX array to DLPack, then to PyTorch tensor
        from torch.utils import dlpack as torch_dlpack
        self.observations = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(obs))
        self.rewards = np.asarray(reward)
        self.terminals = np.asarray(done)
        infos = [{k: v.mean().item() for k, v in info.items()}]
        return self.observations, self.rewards, self.terminals, self.terminals, infos
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import gym
|
|
3
|
+
import gymnasium
|
|
4
|
+
|
|
5
|
+
Box = (gym.spaces.Box, gymnasium.spaces.Box)
|
|
6
|
+
Dict = (gym.spaces.Dict, gymnasium.spaces.Dict)
|
|
7
|
+
Discrete = (gym.spaces.Discrete, gymnasium.spaces.Discrete)
|
|
8
|
+
MultiBinary = (gym.spaces.MultiBinary, gymnasium.spaces.MultiBinary)
|
|
9
|
+
MultiDiscrete = (gym.spaces.MultiDiscrete, gymnasium.spaces.MultiDiscrete)
|
|
10
|
+
Tuple = (gym.spaces.Tuple, gymnasium.spaces.Tuple)
|
|
11
|
+
|
|
12
|
+
def joint_space(space, n):
    '''Combine n copies of space into one joint space.

    Discrete becomes a MultiDiscrete of n entries; MultiDiscrete and Box
    become a Box with a leading dimension of n.

    Raises:
        ValueError: If space is not Discrete, MultiDiscrete, or Box.
    '''
    if isinstance(space, Discrete):
        return gymnasium.spaces.MultiDiscrete([space.n] * n)

    if isinstance(space, MultiDiscrete):
        # Largest valid value per component is nvec - 1
        upper = np.repeat(space.nvec[None] - 1, n, axis=0)
        return gymnasium.spaces.Box(
            low=0, high=upper, shape=(n, len(space)), dtype=space.dtype)

    if isinstance(space, Box):
        lows = np.repeat(space.low[None], n, axis=0)
        highs = np.repeat(space.high[None], n, axis=0)
        return gymnasium.spaces.Box(
            low=lows, high=highs, shape=(n, *space.shape), dtype=space.dtype)

    raise ValueError(f'Unsupported space: {space}')
|