continual-foragax 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.20.1.dist-info → continual_foragax-0.22.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.20.1.dist-info → continual_foragax-0.22.0.dist-info}/RECORD +8 -8
- foragax/env.py +63 -51
- foragax/objects.py +28 -0
- foragax/registry.py +31 -1
- {continual_foragax-0.20.1.dist-info → continual_foragax-0.22.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.20.1.dist-info → continual_foragax-0.22.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.20.1.dist-info → continual_foragax-0.22.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=Q-96fMoA_51TIJky2JXApITQHXC-1QfdeB5VZvNwe0o,21362
|
4
|
+
foragax/objects.py,sha256=8tBFMiquWCkhOpNndNmzovMjw7lE5P81OOlUvN2F65w,8301
|
5
|
+
foragax/registry.py,sha256=pRBWGP18jd4NKl1H-rwDYaAJKUgRWVfENQ9pvTS0tAw,9462
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.22.0.dist-info/METADATA,sha256=WCVeg6996zpBsLWNghDpibCAs70CDaI8KStzCFnSPNM,4897
|
132
|
+
continual_foragax-0.22.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.22.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.22.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.22.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, Tuple, Union
|
|
10
10
|
|
11
11
|
import jax
|
12
12
|
import jax.numpy as jnp
|
13
|
+
import numpy as np
|
13
14
|
from flax import struct
|
14
15
|
from gymnax.environments import environment, spaces
|
15
16
|
|
@@ -66,13 +67,16 @@ class ForagaxEnv(environment.Environment):
|
|
66
67
|
|
67
68
|
def __init__(
|
68
69
|
self,
|
70
|
+
name: str = "Foragax-v0",
|
69
71
|
size: Union[Tuple[int, int], int] = (10, 10),
|
70
72
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
71
73
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
72
74
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
73
75
|
nowrap: bool = False,
|
76
|
+
deterministic_spawn: bool = False,
|
74
77
|
):
|
75
78
|
super().__init__()
|
79
|
+
self._name = name
|
76
80
|
if isinstance(size, int):
|
77
81
|
size = (size, size)
|
78
82
|
self.size = size
|
@@ -81,6 +85,7 @@ class ForagaxEnv(environment.Environment):
|
|
81
85
|
aperture_size = (aperture_size, aperture_size)
|
82
86
|
self.aperture_size = aperture_size
|
83
87
|
self.nowrap = nowrap
|
88
|
+
self.deterministic_spawn = deterministic_spawn
|
84
89
|
objects = (EMPTY,) + objects
|
85
90
|
if self.nowrap:
|
86
91
|
objects = objects + (PADDING,)
|
@@ -103,12 +108,35 @@ class ForagaxEnv(environment.Environment):
|
|
103
108
|
self.biome_object_frequencies = jnp.array(
|
104
109
|
[b.object_frequencies for b in biomes]
|
105
110
|
)
|
106
|
-
self.biome_starts =
|
111
|
+
self.biome_starts = np.array(
|
107
112
|
[b.start if b.start is not None else (-1, -1) for b in biomes]
|
108
113
|
)
|
109
|
-
self.biome_stops =
|
114
|
+
self.biome_stops = np.array(
|
110
115
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
111
116
|
)
|
117
|
+
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
118
|
+
self.biome_masks = []
|
119
|
+
for i in range(self.biome_object_frequencies.shape[0]):
|
120
|
+
# Create mask for the biome
|
121
|
+
start = jax.lax.select(
|
122
|
+
self.biome_starts[i, 0] == -1,
|
123
|
+
jnp.array([0, 0]),
|
124
|
+
self.biome_starts[i],
|
125
|
+
)
|
126
|
+
stop = jax.lax.select(
|
127
|
+
self.biome_stops[i, 0] == -1,
|
128
|
+
jnp.array(self.size),
|
129
|
+
self.biome_stops[i],
|
130
|
+
)
|
131
|
+
rows = jnp.arange(self.size[1])[:, None]
|
132
|
+
cols = jnp.arange(self.size[0])
|
133
|
+
mask = (
|
134
|
+
(rows >= start[1])
|
135
|
+
& (rows < stop[1])
|
136
|
+
& (cols >= start[0])
|
137
|
+
& (cols < stop[0])
|
138
|
+
)
|
139
|
+
self.biome_masks.append(mask)
|
112
140
|
|
113
141
|
@property
|
114
142
|
def default_params(self) -> EnvParams:
|
@@ -196,57 +224,18 @@ class ForagaxEnv(environment.Environment):
|
|
196
224
|
self, key: jax.Array, params: EnvParams
|
197
225
|
) -> Tuple[jax.Array, EnvState]:
|
198
226
|
"""Reset environment state."""
|
199
|
-
key, subkey = jax.random.split(key)
|
200
|
-
|
201
227
|
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
202
|
-
|
203
|
-
iter_key = subkey
|
228
|
+
key, iter_key = jax.random.split(key)
|
204
229
|
for i in range(self.biome_object_frequencies.shape[0]):
|
205
230
|
iter_key, biome_key = jax.random.split(iter_key)
|
206
|
-
|
207
|
-
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
208
|
-
|
209
|
-
# Create mask for the biome
|
210
|
-
start = jax.lax.select(
|
211
|
-
self.biome_starts[i, 0] == -1,
|
212
|
-
jnp.array([0, 0]),
|
213
|
-
self.biome_starts[i],
|
214
|
-
)
|
215
|
-
stop = jax.lax.select(
|
216
|
-
self.biome_stops[i, 0] == -1,
|
217
|
-
jnp.array(self.size),
|
218
|
-
self.biome_stops[i],
|
219
|
-
)
|
220
|
-
|
221
|
-
rows = jnp.arange(self.size[1])[:, None]
|
222
|
-
cols = jnp.arange(self.size[0])
|
231
|
+
mask = self.biome_masks[i]
|
223
232
|
|
224
|
-
|
225
|
-
(
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
# Generate objects for this biome and update the main grid
|
232
|
-
biome_freqs = self.biome_object_frequencies[i]
|
233
|
-
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
234
|
-
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
235
|
-
|
236
|
-
cumulative_freqs = jnp.cumsum(
|
237
|
-
jnp.concatenate([jnp.array([0.0]), all_freqs])
|
238
|
-
)
|
239
|
-
|
240
|
-
# Determine which object to place in each cell
|
241
|
-
# The last object ID will be used for any value of grid_rand >= cumulative_freqs[-1]
|
242
|
-
# so we don't need to cap grid_rand
|
243
|
-
obj_ids_for_biome = jnp.arange(len(all_freqs))
|
244
|
-
cell_obj_ids = (
|
245
|
-
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
246
|
-
)
|
247
|
-
biome_objects = obj_ids_for_biome[cell_obj_ids]
|
248
|
-
|
249
|
-
object_grid = jnp.where(mask, biome_objects, object_grid)
|
233
|
+
if self.deterministic_spawn:
|
234
|
+
biome_objects = self.generate_biome_new(i, biome_key)
|
235
|
+
object_grid = object_grid.at[mask].set(biome_objects)
|
236
|
+
else:
|
237
|
+
biome_objects = self.generate_biome_old(i, biome_key)
|
238
|
+
object_grid = jnp.where(mask, biome_objects, object_grid)
|
250
239
|
|
251
240
|
# Place agent in the center of the world and ensure the cell is empty.
|
252
241
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
@@ -260,6 +249,25 @@ class ForagaxEnv(environment.Environment):
|
|
260
249
|
|
261
250
|
return self.get_obs(state, params), state
|
262
251
|
|
252
|
+
def generate_biome_old(self, i: int, biome_key: jax.Array):
|
253
|
+
biome_freqs = self.biome_object_frequencies[i]
|
254
|
+
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
255
|
+
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
256
|
+
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
257
|
+
cumulative_freqs = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), all_freqs]))
|
258
|
+
biome_objects = jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
259
|
+
return biome_objects
|
260
|
+
|
261
|
+
def generate_biome_new(self, i: int, biome_key: jax.Array):
|
262
|
+
biome_freqs = self.biome_object_frequencies[i]
|
263
|
+
grid = jnp.linspace(0, 1, self.biome_sizes[i], endpoint=False)
|
264
|
+
biome_objects = len(biome_freqs) - jnp.searchsorted(
|
265
|
+
jnp.cumsum(biome_freqs[::-1]), grid, side="right"
|
266
|
+
)
|
267
|
+
flat_biome_objects = biome_objects.flatten()
|
268
|
+
shuffled_objects = jax.random.permutation(biome_key, flat_biome_objects)
|
269
|
+
return shuffled_objects
|
270
|
+
|
263
271
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
264
272
|
"""Foragax is a continuing environment."""
|
265
273
|
return False
|
@@ -267,7 +275,7 @@ class ForagaxEnv(environment.Environment):
|
|
267
275
|
@property
|
268
276
|
def name(self) -> str:
|
269
277
|
"""Environment name."""
|
270
|
-
return
|
278
|
+
return self._name
|
271
279
|
|
272
280
|
@property
|
273
281
|
def num_actions(self) -> int:
|
@@ -438,13 +446,17 @@ class ForagaxObjectEnv(ForagaxEnv):
|
|
438
446
|
|
439
447
|
def __init__(
|
440
448
|
self,
|
449
|
+
name: str = "Foragax-v0",
|
441
450
|
size: Union[Tuple[int, int], int] = (10, 10),
|
442
451
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
443
452
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
444
453
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
445
454
|
nowrap: bool = False,
|
455
|
+
deterministic_spawn: bool = False,
|
446
456
|
):
|
447
|
-
super().__init__(
|
457
|
+
super().__init__(
|
458
|
+
name, size, aperture_size, objects, biomes, nowrap, deterministic_spawn
|
459
|
+
)
|
448
460
|
|
449
461
|
# Compute unique colors and mapping for partial observability
|
450
462
|
# Exclude EMPTY (index 0) from color channels
|
foragax/objects.py
CHANGED
@@ -242,6 +242,34 @@ GREEN_FAKE_2 = NormalRegenForagaxObject(
|
|
242
242
|
mean_regen_delay=10,
|
243
243
|
std_regen_delay=1,
|
244
244
|
)
|
245
|
+
BROWN_MOREL_UNIFORM = DefaultForagaxObject(
|
246
|
+
name="brown_morel",
|
247
|
+
reward=10.0,
|
248
|
+
collectable=True,
|
249
|
+
color=(63, 30, 25),
|
250
|
+
regen_delay=(90, 110),
|
251
|
+
)
|
252
|
+
BROWN_OYSTER_UNIFORM = DefaultForagaxObject(
|
253
|
+
name="brown_oyster",
|
254
|
+
reward=1.0,
|
255
|
+
collectable=True,
|
256
|
+
color=(63, 30, 25),
|
257
|
+
regen_delay=(9, 11),
|
258
|
+
)
|
259
|
+
GREEN_DEATHCAP_UNIFORM = DefaultForagaxObject(
|
260
|
+
name="green_deathcap",
|
261
|
+
reward=-5.0,
|
262
|
+
collectable=True,
|
263
|
+
color=(0, 255, 0),
|
264
|
+
regen_delay=(9, 11),
|
265
|
+
)
|
266
|
+
GREEN_FAKE_UNIFORM = DefaultForagaxObject(
|
267
|
+
name="green_fake",
|
268
|
+
reward=0.0,
|
269
|
+
collectable=True,
|
270
|
+
color=(0, 255, 0),
|
271
|
+
regen_delay=(9, 11),
|
272
|
+
)
|
245
273
|
|
246
274
|
|
247
275
|
def create_weather_objects(
|
foragax/registry.py
CHANGED
@@ -12,12 +12,16 @@ from foragax.env import (
|
|
12
12
|
from foragax.objects import (
|
13
13
|
BROWN_MOREL,
|
14
14
|
BROWN_MOREL_2,
|
15
|
+
BROWN_MOREL_UNIFORM,
|
15
16
|
BROWN_OYSTER,
|
17
|
+
BROWN_OYSTER_UNIFORM,
|
16
18
|
GREEN_DEATHCAP,
|
17
19
|
GREEN_DEATHCAP_2,
|
18
20
|
GREEN_DEATHCAP_3,
|
21
|
+
GREEN_DEATHCAP_UNIFORM,
|
19
22
|
GREEN_FAKE,
|
20
23
|
GREEN_FAKE_2,
|
24
|
+
GREEN_FAKE_UNIFORM,
|
21
25
|
LARGE_MOREL,
|
22
26
|
LARGE_OYSTER,
|
23
27
|
MEDIUM_MOREL,
|
@@ -147,6 +151,31 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
147
151
|
"biomes": None,
|
148
152
|
"nowrap": True,
|
149
153
|
},
|
154
|
+
"ForagaxTwoBiome-v8": {
|
155
|
+
"size": None,
|
156
|
+
"aperture_size": None,
|
157
|
+
"objects": (
|
158
|
+
BROWN_MOREL_UNIFORM,
|
159
|
+
BROWN_OYSTER_UNIFORM,
|
160
|
+
GREEN_DEATHCAP_UNIFORM,
|
161
|
+
GREEN_FAKE_UNIFORM,
|
162
|
+
),
|
163
|
+
"biomes": None,
|
164
|
+
"nowrap": True,
|
165
|
+
},
|
166
|
+
"ForagaxTwoBiome-v9": {
|
167
|
+
"size": None,
|
168
|
+
"aperture_size": None,
|
169
|
+
"objects": (
|
170
|
+
BROWN_MOREL_UNIFORM,
|
171
|
+
BROWN_OYSTER_UNIFORM,
|
172
|
+
GREEN_DEATHCAP_UNIFORM,
|
173
|
+
GREEN_FAKE_UNIFORM,
|
174
|
+
),
|
175
|
+
"biomes": None,
|
176
|
+
"nowrap": True,
|
177
|
+
"deterministic_spawn": True,
|
178
|
+
},
|
150
179
|
"ForagaxTwoBiomeSmall-v1": {
|
151
180
|
"size": (16, 8),
|
152
181
|
"aperture_size": None,
|
@@ -217,7 +246,7 @@ def make(
|
|
217
246
|
if nowrap is not None:
|
218
247
|
config["nowrap"] = nowrap
|
219
248
|
|
220
|
-
if env_id
|
249
|
+
if env_id in ("ForagaxTwoBiome-v7", "ForagaxTwoBiome-v8", "ForagaxTwoBiome-v9"):
|
221
250
|
margin = aperture_size[1] // 2 + 1
|
222
251
|
width = 2 * margin + 9
|
223
252
|
config["size"] = (width, 15)
|
@@ -270,5 +299,6 @@ def make(
|
|
270
299
|
raise ValueError(f"Unknown observation type: {observation_type}")
|
271
300
|
|
272
301
|
env_class = env_class_map[observation_type]
|
302
|
+
config["name"] = env_id
|
273
303
|
|
274
304
|
return env_class(**config)
|
File without changes
|
File without changes
|
File without changes
|