continual-foragax 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.21.0.dist-info → continual_foragax-0.22.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.21.0.dist-info → continual_foragax-0.22.0.dist-info}/RECORD +7 -7
- foragax/env.py +63 -51
- foragax/registry.py +15 -1
- {continual_foragax-0.21.0.dist-info → continual_foragax-0.22.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.21.0.dist-info → continual_foragax-0.22.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.21.0.dist-info → continual_foragax-0.22.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
3
|
+
foragax/env.py,sha256=Q-96fMoA_51TIJky2JXApITQHXC-1QfdeB5VZvNwe0o,21362
|
4
4
|
foragax/objects.py,sha256=8tBFMiquWCkhOpNndNmzovMjw7lE5P81OOlUvN2F65w,8301
|
5
|
-
foragax/registry.py,sha256=
|
5
|
+
foragax/registry.py,sha256=pRBWGP18jd4NKl1H-rwDYaAJKUgRWVfENQ9pvTS0tAw,9462
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.22.0.dist-info/METADATA,sha256=WCVeg6996zpBsLWNghDpibCAs70CDaI8KStzCFnSPNM,4897
|
132
|
+
continual_foragax-0.22.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.22.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.22.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.22.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, Tuple, Union
|
|
10
10
|
|
11
11
|
import jax
|
12
12
|
import jax.numpy as jnp
|
13
|
+
import numpy as np
|
13
14
|
from flax import struct
|
14
15
|
from gymnax.environments import environment, spaces
|
15
16
|
|
@@ -66,13 +67,16 @@ class ForagaxEnv(environment.Environment):
|
|
66
67
|
|
67
68
|
def __init__(
|
68
69
|
self,
|
70
|
+
name: str = "Foragax-v0",
|
69
71
|
size: Union[Tuple[int, int], int] = (10, 10),
|
70
72
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
71
73
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
72
74
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
73
75
|
nowrap: bool = False,
|
76
|
+
deterministic_spawn: bool = False,
|
74
77
|
):
|
75
78
|
super().__init__()
|
79
|
+
self._name = name
|
76
80
|
if isinstance(size, int):
|
77
81
|
size = (size, size)
|
78
82
|
self.size = size
|
@@ -81,6 +85,7 @@ class ForagaxEnv(environment.Environment):
|
|
81
85
|
aperture_size = (aperture_size, aperture_size)
|
82
86
|
self.aperture_size = aperture_size
|
83
87
|
self.nowrap = nowrap
|
88
|
+
self.deterministic_spawn = deterministic_spawn
|
84
89
|
objects = (EMPTY,) + objects
|
85
90
|
if self.nowrap:
|
86
91
|
objects = objects + (PADDING,)
|
@@ -103,12 +108,35 @@ class ForagaxEnv(environment.Environment):
|
|
103
108
|
self.biome_object_frequencies = jnp.array(
|
104
109
|
[b.object_frequencies for b in biomes]
|
105
110
|
)
|
106
|
-
self.biome_starts =
|
111
|
+
self.biome_starts = np.array(
|
107
112
|
[b.start if b.start is not None else (-1, -1) for b in biomes]
|
108
113
|
)
|
109
|
-
self.biome_stops =
|
114
|
+
self.biome_stops = np.array(
|
110
115
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
111
116
|
)
|
117
|
+
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
118
|
+
self.biome_masks = []
|
119
|
+
for i in range(self.biome_object_frequencies.shape[0]):
|
120
|
+
# Create mask for the biome
|
121
|
+
start = jax.lax.select(
|
122
|
+
self.biome_starts[i, 0] == -1,
|
123
|
+
jnp.array([0, 0]),
|
124
|
+
self.biome_starts[i],
|
125
|
+
)
|
126
|
+
stop = jax.lax.select(
|
127
|
+
self.biome_stops[i, 0] == -1,
|
128
|
+
jnp.array(self.size),
|
129
|
+
self.biome_stops[i],
|
130
|
+
)
|
131
|
+
rows = jnp.arange(self.size[1])[:, None]
|
132
|
+
cols = jnp.arange(self.size[0])
|
133
|
+
mask = (
|
134
|
+
(rows >= start[1])
|
135
|
+
& (rows < stop[1])
|
136
|
+
& (cols >= start[0])
|
137
|
+
& (cols < stop[0])
|
138
|
+
)
|
139
|
+
self.biome_masks.append(mask)
|
112
140
|
|
113
141
|
@property
|
114
142
|
def default_params(self) -> EnvParams:
|
@@ -196,57 +224,18 @@ class ForagaxEnv(environment.Environment):
|
|
196
224
|
self, key: jax.Array, params: EnvParams
|
197
225
|
) -> Tuple[jax.Array, EnvState]:
|
198
226
|
"""Reset environment state."""
|
199
|
-
key, subkey = jax.random.split(key)
|
200
|
-
|
201
227
|
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
202
|
-
|
203
|
-
iter_key = subkey
|
228
|
+
key, iter_key = jax.random.split(key)
|
204
229
|
for i in range(self.biome_object_frequencies.shape[0]):
|
205
230
|
iter_key, biome_key = jax.random.split(iter_key)
|
206
|
-
|
207
|
-
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
208
|
-
|
209
|
-
# Create mask for the biome
|
210
|
-
start = jax.lax.select(
|
211
|
-
self.biome_starts[i, 0] == -1,
|
212
|
-
jnp.array([0, 0]),
|
213
|
-
self.biome_starts[i],
|
214
|
-
)
|
215
|
-
stop = jax.lax.select(
|
216
|
-
self.biome_stops[i, 0] == -1,
|
217
|
-
jnp.array(self.size),
|
218
|
-
self.biome_stops[i],
|
219
|
-
)
|
220
|
-
|
221
|
-
rows = jnp.arange(self.size[1])[:, None]
|
222
|
-
cols = jnp.arange(self.size[0])
|
231
|
+
mask = self.biome_masks[i]
|
223
232
|
|
224
|
-
|
225
|
-
(
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
# Generate objects for this biome and update the main grid
|
232
|
-
biome_freqs = self.biome_object_frequencies[i]
|
233
|
-
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
234
|
-
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
235
|
-
|
236
|
-
cumulative_freqs = jnp.cumsum(
|
237
|
-
jnp.concatenate([jnp.array([0.0]), all_freqs])
|
238
|
-
)
|
239
|
-
|
240
|
-
# Determine which object to place in each cell
|
241
|
-
# The last object ID will be used for any value of grid_rand >= cumulative_freqs[-1]
|
242
|
-
# so we don't need to cap grid_rand
|
243
|
-
obj_ids_for_biome = jnp.arange(len(all_freqs))
|
244
|
-
cell_obj_ids = (
|
245
|
-
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
246
|
-
)
|
247
|
-
biome_objects = obj_ids_for_biome[cell_obj_ids]
|
248
|
-
|
249
|
-
object_grid = jnp.where(mask, biome_objects, object_grid)
|
233
|
+
if self.deterministic_spawn:
|
234
|
+
biome_objects = self.generate_biome_new(i, biome_key)
|
235
|
+
object_grid = object_grid.at[mask].set(biome_objects)
|
236
|
+
else:
|
237
|
+
biome_objects = self.generate_biome_old(i, biome_key)
|
238
|
+
object_grid = jnp.where(mask, biome_objects, object_grid)
|
250
239
|
|
251
240
|
# Place agent in the center of the world and ensure the cell is empty.
|
252
241
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
@@ -260,6 +249,25 @@ class ForagaxEnv(environment.Environment):
|
|
260
249
|
|
261
250
|
return self.get_obs(state, params), state
|
262
251
|
|
252
|
+
def generate_biome_old(self, i: int, biome_key: jax.Array):
|
253
|
+
biome_freqs = self.biome_object_frequencies[i]
|
254
|
+
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
255
|
+
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
256
|
+
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
257
|
+
cumulative_freqs = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), all_freqs]))
|
258
|
+
biome_objects = jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
259
|
+
return biome_objects
|
260
|
+
|
261
|
+
def generate_biome_new(self, i: int, biome_key: jax.Array):
|
262
|
+
biome_freqs = self.biome_object_frequencies[i]
|
263
|
+
grid = jnp.linspace(0, 1, self.biome_sizes[i], endpoint=False)
|
264
|
+
biome_objects = len(biome_freqs) - jnp.searchsorted(
|
265
|
+
jnp.cumsum(biome_freqs[::-1]), grid, side="right"
|
266
|
+
)
|
267
|
+
flat_biome_objects = biome_objects.flatten()
|
268
|
+
shuffled_objects = jax.random.permutation(biome_key, flat_biome_objects)
|
269
|
+
return shuffled_objects
|
270
|
+
|
263
271
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
264
272
|
"""Foragax is a continuing environment."""
|
265
273
|
return False
|
@@ -267,7 +275,7 @@ class ForagaxEnv(environment.Environment):
|
|
267
275
|
@property
|
268
276
|
def name(self) -> str:
|
269
277
|
"""Environment name."""
|
270
|
-
return
|
278
|
+
return self._name
|
271
279
|
|
272
280
|
@property
|
273
281
|
def num_actions(self) -> int:
|
@@ -438,13 +446,17 @@ class ForagaxObjectEnv(ForagaxEnv):
|
|
438
446
|
|
439
447
|
def __init__(
|
440
448
|
self,
|
449
|
+
name: str = "Foragax-v0",
|
441
450
|
size: Union[Tuple[int, int], int] = (10, 10),
|
442
451
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
443
452
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
444
453
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
445
454
|
nowrap: bool = False,
|
455
|
+
deterministic_spawn: bool = False,
|
446
456
|
):
|
447
|
-
super().__init__(
|
457
|
+
super().__init__(
|
458
|
+
name, size, aperture_size, objects, biomes, nowrap, deterministic_spawn
|
459
|
+
)
|
448
460
|
|
449
461
|
# Compute unique colors and mapping for partial observability
|
450
462
|
# Exclude EMPTY (index 0) from color channels
|
foragax/registry.py
CHANGED
@@ -163,6 +163,19 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
163
163
|
"biomes": None,
|
164
164
|
"nowrap": True,
|
165
165
|
},
|
166
|
+
"ForagaxTwoBiome-v9": {
|
167
|
+
"size": None,
|
168
|
+
"aperture_size": None,
|
169
|
+
"objects": (
|
170
|
+
BROWN_MOREL_UNIFORM,
|
171
|
+
BROWN_OYSTER_UNIFORM,
|
172
|
+
GREEN_DEATHCAP_UNIFORM,
|
173
|
+
GREEN_FAKE_UNIFORM,
|
174
|
+
),
|
175
|
+
"biomes": None,
|
176
|
+
"nowrap": True,
|
177
|
+
"deterministic_spawn": True,
|
178
|
+
},
|
166
179
|
"ForagaxTwoBiomeSmall-v1": {
|
167
180
|
"size": (16, 8),
|
168
181
|
"aperture_size": None,
|
@@ -233,7 +246,7 @@ def make(
|
|
233
246
|
if nowrap is not None:
|
234
247
|
config["nowrap"] = nowrap
|
235
248
|
|
236
|
-
if env_id in ("ForagaxTwoBiome-v7", "ForagaxTwoBiome-v8"):
|
249
|
+
if env_id in ("ForagaxTwoBiome-v7", "ForagaxTwoBiome-v8", "ForagaxTwoBiome-v9"):
|
237
250
|
margin = aperture_size[1] // 2 + 1
|
238
251
|
width = 2 * margin + 9
|
239
252
|
config["size"] = (width, 15)
|
@@ -286,5 +299,6 @@ def make(
|
|
286
299
|
raise ValueError(f"Unknown observation type: {observation_type}")
|
287
300
|
|
288
301
|
env_class = env_class_map[observation_type]
|
302
|
+
config["name"] = env_id
|
289
303
|
|
290
304
|
return env_class(**config)
|
File without changes
|
File without changes
|
File without changes
|