continual-foragax 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.0.dist-info}/RECORD +8 -8
- foragax/env.py +61 -4
- foragax/objects.py +42 -1
- foragax/registry.py +68 -1
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=7csniZeHv2ZX6-8CoYm0pMue7tXmmDwr2XU5msH0n0Q,23680
|
4
|
+
foragax/objects.py,sha256=CblI0NI7PQzeKk3MZa8sbaa9wB4pc_8CyOGbJOFWytE,9391
|
5
|
+
foragax/registry.py,sha256=EbvPn2IpEMZ7HRjkAN7_JWaTJfYgjLhHkqspcT88DiY,11339
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.23.0.dist-info/METADATA,sha256=EIaFemphhlG9KTITgvT_qTQfjKS7sQl7hdrxCceAsW4,4897
|
132
|
+
continual_foragax-0.23.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.23.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.23.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.23.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -59,6 +59,7 @@ class EnvParams(environment.EnvParams):
|
|
59
59
|
class EnvState(environment.EnvState):
|
60
60
|
pos: jax.Array
|
61
61
|
object_grid: jax.Array
|
62
|
+
biome_grid: jax.Array
|
62
63
|
time: int
|
63
64
|
|
64
65
|
|
@@ -101,6 +102,7 @@ class ForagaxEnv(environment.Environment):
|
|
101
102
|
self.object_blocking = jnp.array([o.blocking for o in objects])
|
102
103
|
self.object_collectable = jnp.array([o.collectable for o in objects])
|
103
104
|
self.object_colors = jnp.array([o.color for o in objects])
|
105
|
+
self.object_random_respawn = jnp.array([o.random_respawn for o in objects])
|
104
106
|
|
105
107
|
self.reward_fns = [o.reward for o in objects]
|
106
108
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
@@ -179,7 +181,7 @@ class ForagaxEnv(environment.Environment):
|
|
179
181
|
is_collectable = self.object_collectable[obj_at_pos]
|
180
182
|
|
181
183
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
182
|
-
key, subkey = jax.random.split(key)
|
184
|
+
key, subkey, rand_key = jax.random.split(key, 3)
|
183
185
|
|
184
186
|
# Decrement timers (stored as negative values)
|
185
187
|
is_timer = state.object_grid < 0
|
@@ -193,10 +195,49 @@ class ForagaxEnv(environment.Environment):
|
|
193
195
|
)
|
194
196
|
encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
|
195
197
|
|
198
|
+
def place_at_current_pos(current_grid, timer_val):
|
199
|
+
return current_grid.at[pos[1], pos[0]].set(timer_val)
|
200
|
+
|
201
|
+
def place_at_random_pos(current_grid, timer_val):
|
202
|
+
# Set the collected position to empty temporarily
|
203
|
+
grid = current_grid.at[pos[1], pos[0]].set(0)
|
204
|
+
|
205
|
+
# Find all valid spawn locations (empty cells within the same biome)
|
206
|
+
biome_id = state.biome_grid[pos[1], pos[0]]
|
207
|
+
biome_mask = state.biome_grid == biome_id
|
208
|
+
empty_mask = grid == 0
|
209
|
+
valid_spawn_mask = biome_mask & empty_mask
|
210
|
+
|
211
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
212
|
+
|
213
|
+
# Get indices of valid spawn locations, padded to a static size
|
214
|
+
y_indices, x_indices = jnp.nonzero(
|
215
|
+
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
216
|
+
)
|
217
|
+
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
218
|
+
|
219
|
+
# Select a random valid location
|
220
|
+
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
221
|
+
new_spawn_pos = valid_spawn_indices[random_idx]
|
222
|
+
|
223
|
+
# Place the timer at the new random position
|
224
|
+
return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
|
225
|
+
|
196
226
|
# If collected, replace object with timer; otherwise, keep it
|
197
227
|
val_at_pos = object_grid[pos[1], pos[0]]
|
198
|
-
|
199
|
-
|
228
|
+
should_collect = is_collectable & (val_at_pos > 0)
|
229
|
+
|
230
|
+
# When not collecting, the value at the position remains unchanged.
|
231
|
+
# When collecting, we either place the timer at the current position or a random one.
|
232
|
+
object_grid = jax.lax.cond(
|
233
|
+
should_collect,
|
234
|
+
lambda: jax.lax.cond(
|
235
|
+
self.object_random_respawn[obj_at_pos],
|
236
|
+
lambda: place_at_random_pos(object_grid, encoded_timer),
|
237
|
+
lambda: place_at_current_pos(object_grid, encoded_timer),
|
238
|
+
),
|
239
|
+
lambda: object_grid,
|
240
|
+
)
|
200
241
|
|
201
242
|
info = {"discount": self.discount(state, params)}
|
202
243
|
if self.weather_object is not None:
|
@@ -208,6 +249,7 @@ class ForagaxEnv(environment.Environment):
|
|
208
249
|
state = EnvState(
|
209
250
|
pos=pos,
|
210
251
|
object_grid=object_grid,
|
252
|
+
biome_grid=state.biome_grid,
|
211
253
|
time=state.time + 1,
|
212
254
|
)
|
213
255
|
|
@@ -225,10 +267,12 @@ class ForagaxEnv(environment.Environment):
|
|
225
267
|
) -> Tuple[jax.Array, EnvState]:
|
226
268
|
"""Reset environment state."""
|
227
269
|
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
270
|
+
biome_grid = jnp.full((self.size[1], self.size[0]), -1, dtype=int)
|
228
271
|
key, iter_key = jax.random.split(key)
|
229
272
|
for i in range(self.biome_object_frequencies.shape[0]):
|
230
273
|
iter_key, biome_key = jax.random.split(iter_key)
|
231
274
|
mask = self.biome_masks[i]
|
275
|
+
biome_grid = jnp.where(mask, i, biome_grid)
|
232
276
|
|
233
277
|
if self.deterministic_spawn:
|
234
278
|
biome_objects = self.generate_biome_new(i, biome_key)
|
@@ -244,6 +288,7 @@ class ForagaxEnv(environment.Environment):
|
|
244
288
|
state = EnvState(
|
245
289
|
pos=agent_pos,
|
246
290
|
object_grid=object_grid,
|
291
|
+
biome_grid=biome_grid,
|
247
292
|
time=0,
|
248
293
|
)
|
249
294
|
|
@@ -298,6 +343,12 @@ class ForagaxEnv(environment.Environment):
|
|
298
343
|
(self.size[1], self.size[0]),
|
299
344
|
int,
|
300
345
|
),
|
346
|
+
"biome_grid": spaces.Box(
|
347
|
+
0,
|
348
|
+
self.biome_object_frequencies.shape[0],
|
349
|
+
(self.size[1], self.size[0]),
|
350
|
+
int,
|
351
|
+
),
|
301
352
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
302
353
|
}
|
303
354
|
)
|
@@ -455,7 +506,13 @@ class ForagaxObjectEnv(ForagaxEnv):
|
|
455
506
|
deterministic_spawn: bool = False,
|
456
507
|
):
|
457
508
|
super().__init__(
|
458
|
-
name,
|
509
|
+
name,
|
510
|
+
size,
|
511
|
+
aperture_size,
|
512
|
+
objects,
|
513
|
+
biomes,
|
514
|
+
nowrap,
|
515
|
+
deterministic_spawn,
|
459
516
|
)
|
460
517
|
|
461
518
|
# Compute unique colors and mapping for partial observability
|
foragax/objects.py
CHANGED
@@ -16,11 +16,13 @@ class BaseForagaxObject:
|
|
16
16
|
blocking: bool = False,
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
|
+
random_respawn: bool = False,
|
19
20
|
):
|
20
21
|
self.name = name
|
21
22
|
self.blocking = blocking
|
22
23
|
self.collectable = collectable
|
23
24
|
self.color = color
|
25
|
+
self.random_respawn = random_respawn
|
24
26
|
|
25
27
|
@abc.abstractmethod
|
26
28
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -44,8 +46,9 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
44
46
|
collectable: bool = False,
|
45
47
|
regen_delay: Tuple[int, int] = (10, 100),
|
46
48
|
color: Tuple[int, int, int] = (255, 255, 255),
|
49
|
+
random_respawn: bool = False,
|
47
50
|
):
|
48
|
-
super().__init__(name, blocking, collectable, color)
|
51
|
+
super().__init__(name, blocking, collectable, color, random_respawn)
|
49
52
|
self.reward_val = reward
|
50
53
|
self.regen_delay_range = regen_delay
|
51
54
|
|
@@ -70,6 +73,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
70
73
|
mean_regen_delay: int = 10,
|
71
74
|
std_regen_delay: int = 1,
|
72
75
|
color: Tuple[int, int, int] = (0, 0, 0),
|
76
|
+
random_respawn: bool = False,
|
73
77
|
):
|
74
78
|
super().__init__(
|
75
79
|
name=name,
|
@@ -77,6 +81,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
77
81
|
collectable=collectable,
|
78
82
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
79
83
|
color=color,
|
84
|
+
random_respawn=random_respawn,
|
80
85
|
)
|
81
86
|
self.mean_regen_delay = mean_regen_delay
|
82
87
|
self.std_regen_delay = std_regen_delay
|
@@ -99,6 +104,7 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
99
104
|
mean_regen_delay: int = 10,
|
100
105
|
std_regen_delay: int = 1,
|
101
106
|
color: Tuple[int, int, int] = (0, 0, 0),
|
107
|
+
random_respawn: bool = False,
|
102
108
|
):
|
103
109
|
super().__init__(
|
104
110
|
name=name,
|
@@ -106,6 +112,7 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
106
112
|
mean_regen_delay=mean_regen_delay,
|
107
113
|
std_regen_delay=std_regen_delay,
|
108
114
|
color=color,
|
115
|
+
random_respawn=random_respawn,
|
109
116
|
)
|
110
117
|
self.rewards = rewards
|
111
118
|
self.repeat = repeat
|
@@ -271,6 +278,40 @@ GREEN_FAKE_UNIFORM = DefaultForagaxObject(
|
|
271
278
|
regen_delay=(9, 11),
|
272
279
|
)
|
273
280
|
|
281
|
+
# Random respawn variants
|
282
|
+
BROWN_MOREL_UNIFORM_RANDOM = DefaultForagaxObject(
|
283
|
+
name="brown_morel",
|
284
|
+
reward=10.0,
|
285
|
+
collectable=True,
|
286
|
+
color=(63, 30, 25),
|
287
|
+
regen_delay=(90, 110),
|
288
|
+
random_respawn=True,
|
289
|
+
)
|
290
|
+
BROWN_OYSTER_UNIFORM_RANDOM = DefaultForagaxObject(
|
291
|
+
name="brown_oyster",
|
292
|
+
reward=1.0,
|
293
|
+
collectable=True,
|
294
|
+
color=(63, 30, 25),
|
295
|
+
regen_delay=(9, 11),
|
296
|
+
random_respawn=True,
|
297
|
+
)
|
298
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM = DefaultForagaxObject(
|
299
|
+
name="green_deathcap",
|
300
|
+
reward=-5.0,
|
301
|
+
collectable=True,
|
302
|
+
color=(0, 255, 0),
|
303
|
+
regen_delay=(9, 11),
|
304
|
+
random_respawn=True,
|
305
|
+
)
|
306
|
+
GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
|
307
|
+
name="green_fake",
|
308
|
+
reward=0.0,
|
309
|
+
collectable=True,
|
310
|
+
color=(0, 255, 0),
|
311
|
+
regen_delay=(9, 11),
|
312
|
+
random_respawn=True,
|
313
|
+
)
|
314
|
+
|
274
315
|
|
275
316
|
def create_weather_objects(
|
276
317
|
file_index: int = 0,
|
foragax/registry.py
CHANGED
@@ -13,15 +13,19 @@ from foragax.objects import (
|
|
13
13
|
BROWN_MOREL,
|
14
14
|
BROWN_MOREL_2,
|
15
15
|
BROWN_MOREL_UNIFORM,
|
16
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
16
17
|
BROWN_OYSTER,
|
17
18
|
BROWN_OYSTER_UNIFORM,
|
19
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
18
20
|
GREEN_DEATHCAP,
|
19
21
|
GREEN_DEATHCAP_2,
|
20
22
|
GREEN_DEATHCAP_3,
|
21
23
|
GREEN_DEATHCAP_UNIFORM,
|
24
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
22
25
|
GREEN_FAKE,
|
23
26
|
GREEN_FAKE_2,
|
24
27
|
GREEN_FAKE_UNIFORM,
|
28
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
25
29
|
LARGE_MOREL,
|
26
30
|
LARGE_OYSTER,
|
27
31
|
MEDIUM_MOREL,
|
@@ -176,6 +180,45 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
176
180
|
"nowrap": True,
|
177
181
|
"deterministic_spawn": True,
|
178
182
|
},
|
183
|
+
"ForagaxTwoBiome-v10": {
|
184
|
+
"size": None,
|
185
|
+
"aperture_size": None,
|
186
|
+
"objects": (
|
187
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
188
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
189
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
190
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
191
|
+
),
|
192
|
+
"biomes": None,
|
193
|
+
"nowrap": True,
|
194
|
+
"deterministic_spawn": True,
|
195
|
+
},
|
196
|
+
"ForagaxTwoBiome-v11": {
|
197
|
+
"size": None,
|
198
|
+
"aperture_size": None,
|
199
|
+
"objects": (
|
200
|
+
BROWN_MOREL_UNIFORM,
|
201
|
+
BROWN_OYSTER_UNIFORM,
|
202
|
+
GREEN_DEATHCAP_UNIFORM,
|
203
|
+
GREEN_FAKE_UNIFORM,
|
204
|
+
),
|
205
|
+
"biomes": None,
|
206
|
+
"nowrap": True,
|
207
|
+
"deterministic_spawn": True,
|
208
|
+
},
|
209
|
+
"ForagaxTwoBiome-v12": {
|
210
|
+
"size": None,
|
211
|
+
"aperture_size": None,
|
212
|
+
"objects": (
|
213
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
214
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
215
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
216
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
217
|
+
),
|
218
|
+
"biomes": None,
|
219
|
+
"nowrap": True,
|
220
|
+
"deterministic_spawn": True,
|
221
|
+
},
|
179
222
|
"ForagaxTwoBiomeSmall-v1": {
|
180
223
|
"size": (16, 8),
|
181
224
|
"aperture_size": None,
|
@@ -246,7 +289,12 @@ def make(
|
|
246
289
|
if nowrap is not None:
|
247
290
|
config["nowrap"] = nowrap
|
248
291
|
|
249
|
-
if env_id in (
|
292
|
+
if env_id in (
|
293
|
+
"ForagaxTwoBiome-v7",
|
294
|
+
"ForagaxTwoBiome-v8",
|
295
|
+
"ForagaxTwoBiome-v9",
|
296
|
+
"ForagaxTwoBiome-v10",
|
297
|
+
):
|
250
298
|
margin = aperture_size[1] // 2 + 1
|
251
299
|
width = 2 * margin + 9
|
252
300
|
config["size"] = (width, 15)
|
@@ -265,6 +313,25 @@ def make(
|
|
265
313
|
),
|
266
314
|
)
|
267
315
|
|
316
|
+
if env_id in ("ForagaxTwoBiome-v11", "ForagaxTwoBiome-v12"):
|
317
|
+
margin = aperture_size[1] // 2 + 1
|
318
|
+
width = 2 * margin + 9
|
319
|
+
config["size"] = (width, 15)
|
320
|
+
config["biomes"] = (
|
321
|
+
# Morel biome
|
322
|
+
Biome(
|
323
|
+
start=(margin, 0),
|
324
|
+
stop=(margin + 2, 15),
|
325
|
+
object_frequencies=(0.5, 0.0, 0.25, 0.0),
|
326
|
+
),
|
327
|
+
# Oyster biome
|
328
|
+
Biome(
|
329
|
+
start=(margin + 7, 0),
|
330
|
+
stop=(margin + 9, 15),
|
331
|
+
object_frequencies=(0.0, 0.5, 0.0, 0.25),
|
332
|
+
),
|
333
|
+
)
|
334
|
+
|
268
335
|
if env_id == "ForagaxWeather-v3":
|
269
336
|
margin = aperture_size[1] // 2 + 1
|
270
337
|
width = 2 * margin + 9
|
File without changes
|
File without changes
|
File without changes
|