continual-foragax 0.22.0__py3-none-any.whl → 0.23.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.1.dist-info}/METADATA +1 -1
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.1.dist-info}/RECORD +8 -8
- foragax/env.py +62 -6
- foragax/objects.py +42 -1
- foragax/registry.py +68 -1
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.1.dist-info}/WHEEL +0 -0
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.1.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.22.0.dist-info → continual_foragax-0.23.1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=t0tb4IiDFiC3xJLLv7s_IFOYgoyRdje1PcnEvJNcyVM,23578
|
4
|
+
foragax/objects.py,sha256=CblI0NI7PQzeKk3MZa8sbaa9wB4pc_8CyOGbJOFWytE,9391
|
5
|
+
foragax/registry.py,sha256=EbvPn2IpEMZ7HRjkAN7_JWaTJfYgjLhHkqspcT88DiY,11339
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.23.1.dist-info/METADATA,sha256=8r-zJyus4axTuoqGPIOKo0-cKRIIT0Phtcd3adW2kVs,4897
|
132
|
+
continual_foragax-0.23.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.23.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.23.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.23.1.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -59,6 +59,7 @@ class EnvParams(environment.EnvParams):
|
|
59
59
|
class EnvState(environment.EnvState):
|
60
60
|
pos: jax.Array
|
61
61
|
object_grid: jax.Array
|
62
|
+
biome_grid: jax.Array
|
62
63
|
time: int
|
63
64
|
|
64
65
|
|
@@ -101,6 +102,7 @@ class ForagaxEnv(environment.Environment):
|
|
101
102
|
self.object_blocking = jnp.array([o.blocking for o in objects])
|
102
103
|
self.object_collectable = jnp.array([o.collectable for o in objects])
|
103
104
|
self.object_colors = jnp.array([o.color for o in objects])
|
105
|
+
self.object_random_respawn = jnp.array([o.random_respawn for o in objects])
|
104
106
|
|
105
107
|
self.reward_fns = [o.reward for o in objects]
|
106
108
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
@@ -179,7 +181,7 @@ class ForagaxEnv(environment.Environment):
|
|
179
181
|
is_collectable = self.object_collectable[obj_at_pos]
|
180
182
|
|
181
183
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
182
|
-
key, subkey = jax.random.split(key)
|
184
|
+
key, subkey, rand_key = jax.random.split(key, 3)
|
183
185
|
|
184
186
|
# Decrement timers (stored as negative values)
|
185
187
|
is_timer = state.object_grid < 0
|
@@ -193,10 +195,49 @@ class ForagaxEnv(environment.Environment):
|
|
193
195
|
)
|
194
196
|
encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
|
195
197
|
|
198
|
+
def place_at_current_pos(current_grid, timer_val):
|
199
|
+
return current_grid.at[pos[1], pos[0]].set(timer_val)
|
200
|
+
|
201
|
+
def place_at_random_pos(current_grid, timer_val):
|
202
|
+
# Set the collected position to empty temporarily
|
203
|
+
grid = current_grid.at[pos[1], pos[0]].set(0)
|
204
|
+
|
205
|
+
# Find all valid spawn locations (empty cells within the same biome)
|
206
|
+
biome_id = state.biome_grid[pos[1], pos[0]]
|
207
|
+
biome_mask = state.biome_grid == biome_id
|
208
|
+
empty_mask = grid == 0
|
209
|
+
valid_spawn_mask = biome_mask & empty_mask
|
210
|
+
|
211
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
212
|
+
|
213
|
+
# Get indices of valid spawn locations, padded to a static size
|
214
|
+
y_indices, x_indices = jnp.nonzero(
|
215
|
+
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
216
|
+
)
|
217
|
+
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
218
|
+
|
219
|
+
# Select a random valid location
|
220
|
+
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
221
|
+
new_spawn_pos = valid_spawn_indices[random_idx]
|
222
|
+
|
223
|
+
# Place the timer at the new random position
|
224
|
+
return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
|
225
|
+
|
196
226
|
# If collected, replace object with timer; otherwise, keep it
|
197
227
|
val_at_pos = object_grid[pos[1], pos[0]]
|
198
|
-
|
199
|
-
|
228
|
+
should_collect = is_collectable & (val_at_pos > 0)
|
229
|
+
|
230
|
+
# When not collecting, the value at the position remains unchanged.
|
231
|
+
# When collecting, we either place the timer at the current position or a random one.
|
232
|
+
object_grid = jax.lax.cond(
|
233
|
+
should_collect,
|
234
|
+
lambda: jax.lax.cond(
|
235
|
+
self.object_random_respawn[obj_at_pos],
|
236
|
+
lambda: place_at_random_pos(object_grid, encoded_timer),
|
237
|
+
lambda: place_at_current_pos(object_grid, encoded_timer),
|
238
|
+
),
|
239
|
+
lambda: object_grid,
|
240
|
+
)
|
200
241
|
|
201
242
|
info = {"discount": self.discount(state, params)}
|
202
243
|
if self.weather_object is not None:
|
@@ -208,6 +249,7 @@ class ForagaxEnv(environment.Environment):
|
|
208
249
|
state = EnvState(
|
209
250
|
pos=pos,
|
210
251
|
object_grid=object_grid,
|
252
|
+
biome_grid=state.biome_grid,
|
211
253
|
time=state.time + 1,
|
212
254
|
)
|
213
255
|
|
@@ -225,10 +267,12 @@ class ForagaxEnv(environment.Environment):
|
|
225
267
|
) -> Tuple[jax.Array, EnvState]:
|
226
268
|
"""Reset environment state."""
|
227
269
|
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
270
|
+
biome_grid = jnp.full((self.size[1], self.size[0]), -1, dtype=int)
|
228
271
|
key, iter_key = jax.random.split(key)
|
229
272
|
for i in range(self.biome_object_frequencies.shape[0]):
|
230
273
|
iter_key, biome_key = jax.random.split(iter_key)
|
231
274
|
mask = self.biome_masks[i]
|
275
|
+
biome_grid = jnp.where(mask, i, biome_grid)
|
232
276
|
|
233
277
|
if self.deterministic_spawn:
|
234
278
|
biome_objects = self.generate_biome_new(i, biome_key)
|
@@ -237,13 +281,13 @@ class ForagaxEnv(environment.Environment):
|
|
237
281
|
biome_objects = self.generate_biome_old(i, biome_key)
|
238
282
|
object_grid = jnp.where(mask, biome_objects, object_grid)
|
239
283
|
|
240
|
-
# Place agent in the center of the world
|
284
|
+
# Place agent in the center of the world
|
241
285
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
242
|
-
object_grid = object_grid.at[agent_pos[1], agent_pos[0]].set(0)
|
243
286
|
|
244
287
|
state = EnvState(
|
245
288
|
pos=agent_pos,
|
246
289
|
object_grid=object_grid,
|
290
|
+
biome_grid=biome_grid,
|
247
291
|
time=0,
|
248
292
|
)
|
249
293
|
|
@@ -298,6 +342,12 @@ class ForagaxEnv(environment.Environment):
|
|
298
342
|
(self.size[1], self.size[0]),
|
299
343
|
int,
|
300
344
|
),
|
345
|
+
"biome_grid": spaces.Box(
|
346
|
+
0,
|
347
|
+
self.biome_object_frequencies.shape[0],
|
348
|
+
(self.size[1], self.size[0]),
|
349
|
+
int,
|
350
|
+
),
|
301
351
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
302
352
|
}
|
303
353
|
)
|
@@ -455,7 +505,13 @@ class ForagaxObjectEnv(ForagaxEnv):
|
|
455
505
|
deterministic_spawn: bool = False,
|
456
506
|
):
|
457
507
|
super().__init__(
|
458
|
-
name,
|
508
|
+
name,
|
509
|
+
size,
|
510
|
+
aperture_size,
|
511
|
+
objects,
|
512
|
+
biomes,
|
513
|
+
nowrap,
|
514
|
+
deterministic_spawn,
|
459
515
|
)
|
460
516
|
|
461
517
|
# Compute unique colors and mapping for partial observability
|
foragax/objects.py
CHANGED
@@ -16,11 +16,13 @@ class BaseForagaxObject:
|
|
16
16
|
blocking: bool = False,
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
|
+
random_respawn: bool = False,
|
19
20
|
):
|
20
21
|
self.name = name
|
21
22
|
self.blocking = blocking
|
22
23
|
self.collectable = collectable
|
23
24
|
self.color = color
|
25
|
+
self.random_respawn = random_respawn
|
24
26
|
|
25
27
|
@abc.abstractmethod
|
26
28
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -44,8 +46,9 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
44
46
|
collectable: bool = False,
|
45
47
|
regen_delay: Tuple[int, int] = (10, 100),
|
46
48
|
color: Tuple[int, int, int] = (255, 255, 255),
|
49
|
+
random_respawn: bool = False,
|
47
50
|
):
|
48
|
-
super().__init__(name, blocking, collectable, color)
|
51
|
+
super().__init__(name, blocking, collectable, color, random_respawn)
|
49
52
|
self.reward_val = reward
|
50
53
|
self.regen_delay_range = regen_delay
|
51
54
|
|
@@ -70,6 +73,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
70
73
|
mean_regen_delay: int = 10,
|
71
74
|
std_regen_delay: int = 1,
|
72
75
|
color: Tuple[int, int, int] = (0, 0, 0),
|
76
|
+
random_respawn: bool = False,
|
73
77
|
):
|
74
78
|
super().__init__(
|
75
79
|
name=name,
|
@@ -77,6 +81,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
77
81
|
collectable=collectable,
|
78
82
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
79
83
|
color=color,
|
84
|
+
random_respawn=random_respawn,
|
80
85
|
)
|
81
86
|
self.mean_regen_delay = mean_regen_delay
|
82
87
|
self.std_regen_delay = std_regen_delay
|
@@ -99,6 +104,7 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
99
104
|
mean_regen_delay: int = 10,
|
100
105
|
std_regen_delay: int = 1,
|
101
106
|
color: Tuple[int, int, int] = (0, 0, 0),
|
107
|
+
random_respawn: bool = False,
|
102
108
|
):
|
103
109
|
super().__init__(
|
104
110
|
name=name,
|
@@ -106,6 +112,7 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
106
112
|
mean_regen_delay=mean_regen_delay,
|
107
113
|
std_regen_delay=std_regen_delay,
|
108
114
|
color=color,
|
115
|
+
random_respawn=random_respawn,
|
109
116
|
)
|
110
117
|
self.rewards = rewards
|
111
118
|
self.repeat = repeat
|
@@ -271,6 +278,40 @@ GREEN_FAKE_UNIFORM = DefaultForagaxObject(
|
|
271
278
|
regen_delay=(9, 11),
|
272
279
|
)
|
273
280
|
|
281
|
+
# Random respawn variants
|
282
|
+
BROWN_MOREL_UNIFORM_RANDOM = DefaultForagaxObject(
|
283
|
+
name="brown_morel",
|
284
|
+
reward=10.0,
|
285
|
+
collectable=True,
|
286
|
+
color=(63, 30, 25),
|
287
|
+
regen_delay=(90, 110),
|
288
|
+
random_respawn=True,
|
289
|
+
)
|
290
|
+
BROWN_OYSTER_UNIFORM_RANDOM = DefaultForagaxObject(
|
291
|
+
name="brown_oyster",
|
292
|
+
reward=1.0,
|
293
|
+
collectable=True,
|
294
|
+
color=(63, 30, 25),
|
295
|
+
regen_delay=(9, 11),
|
296
|
+
random_respawn=True,
|
297
|
+
)
|
298
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM = DefaultForagaxObject(
|
299
|
+
name="green_deathcap",
|
300
|
+
reward=-5.0,
|
301
|
+
collectable=True,
|
302
|
+
color=(0, 255, 0),
|
303
|
+
regen_delay=(9, 11),
|
304
|
+
random_respawn=True,
|
305
|
+
)
|
306
|
+
GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
|
307
|
+
name="green_fake",
|
308
|
+
reward=0.0,
|
309
|
+
collectable=True,
|
310
|
+
color=(0, 255, 0),
|
311
|
+
regen_delay=(9, 11),
|
312
|
+
random_respawn=True,
|
313
|
+
)
|
314
|
+
|
274
315
|
|
275
316
|
def create_weather_objects(
|
276
317
|
file_index: int = 0,
|
foragax/registry.py
CHANGED
@@ -13,15 +13,19 @@ from foragax.objects import (
|
|
13
13
|
BROWN_MOREL,
|
14
14
|
BROWN_MOREL_2,
|
15
15
|
BROWN_MOREL_UNIFORM,
|
16
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
16
17
|
BROWN_OYSTER,
|
17
18
|
BROWN_OYSTER_UNIFORM,
|
19
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
18
20
|
GREEN_DEATHCAP,
|
19
21
|
GREEN_DEATHCAP_2,
|
20
22
|
GREEN_DEATHCAP_3,
|
21
23
|
GREEN_DEATHCAP_UNIFORM,
|
24
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
22
25
|
GREEN_FAKE,
|
23
26
|
GREEN_FAKE_2,
|
24
27
|
GREEN_FAKE_UNIFORM,
|
28
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
25
29
|
LARGE_MOREL,
|
26
30
|
LARGE_OYSTER,
|
27
31
|
MEDIUM_MOREL,
|
@@ -176,6 +180,45 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
176
180
|
"nowrap": True,
|
177
181
|
"deterministic_spawn": True,
|
178
182
|
},
|
183
|
+
"ForagaxTwoBiome-v10": {
|
184
|
+
"size": None,
|
185
|
+
"aperture_size": None,
|
186
|
+
"objects": (
|
187
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
188
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
189
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
190
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
191
|
+
),
|
192
|
+
"biomes": None,
|
193
|
+
"nowrap": True,
|
194
|
+
"deterministic_spawn": True,
|
195
|
+
},
|
196
|
+
"ForagaxTwoBiome-v11": {
|
197
|
+
"size": None,
|
198
|
+
"aperture_size": None,
|
199
|
+
"objects": (
|
200
|
+
BROWN_MOREL_UNIFORM,
|
201
|
+
BROWN_OYSTER_UNIFORM,
|
202
|
+
GREEN_DEATHCAP_UNIFORM,
|
203
|
+
GREEN_FAKE_UNIFORM,
|
204
|
+
),
|
205
|
+
"biomes": None,
|
206
|
+
"nowrap": True,
|
207
|
+
"deterministic_spawn": True,
|
208
|
+
},
|
209
|
+
"ForagaxTwoBiome-v12": {
|
210
|
+
"size": None,
|
211
|
+
"aperture_size": None,
|
212
|
+
"objects": (
|
213
|
+
BROWN_MOREL_UNIFORM_RANDOM,
|
214
|
+
BROWN_OYSTER_UNIFORM_RANDOM,
|
215
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
216
|
+
GREEN_FAKE_UNIFORM_RANDOM,
|
217
|
+
),
|
218
|
+
"biomes": None,
|
219
|
+
"nowrap": True,
|
220
|
+
"deterministic_spawn": True,
|
221
|
+
},
|
179
222
|
"ForagaxTwoBiomeSmall-v1": {
|
180
223
|
"size": (16, 8),
|
181
224
|
"aperture_size": None,
|
@@ -246,7 +289,12 @@ def make(
|
|
246
289
|
if nowrap is not None:
|
247
290
|
config["nowrap"] = nowrap
|
248
291
|
|
249
|
-
if env_id in (
|
292
|
+
if env_id in (
|
293
|
+
"ForagaxTwoBiome-v7",
|
294
|
+
"ForagaxTwoBiome-v8",
|
295
|
+
"ForagaxTwoBiome-v9",
|
296
|
+
"ForagaxTwoBiome-v10",
|
297
|
+
):
|
250
298
|
margin = aperture_size[1] // 2 + 1
|
251
299
|
width = 2 * margin + 9
|
252
300
|
config["size"] = (width, 15)
|
@@ -265,6 +313,25 @@ def make(
|
|
265
313
|
),
|
266
314
|
)
|
267
315
|
|
316
|
+
if env_id in ("ForagaxTwoBiome-v11", "ForagaxTwoBiome-v12"):
|
317
|
+
margin = aperture_size[1] // 2 + 1
|
318
|
+
width = 2 * margin + 9
|
319
|
+
config["size"] = (width, 15)
|
320
|
+
config["biomes"] = (
|
321
|
+
# Morel biome
|
322
|
+
Biome(
|
323
|
+
start=(margin, 0),
|
324
|
+
stop=(margin + 2, 15),
|
325
|
+
object_frequencies=(0.5, 0.0, 0.25, 0.0),
|
326
|
+
),
|
327
|
+
# Oyster biome
|
328
|
+
Biome(
|
329
|
+
start=(margin + 7, 0),
|
330
|
+
stop=(margin + 9, 15),
|
331
|
+
object_frequencies=(0.0, 0.5, 0.0, 0.25),
|
332
|
+
),
|
333
|
+
)
|
334
|
+
|
268
335
|
if env_id == "ForagaxWeather-v3":
|
269
336
|
margin = aperture_size[1] // 2 + 1
|
270
337
|
width = 2 * margin + 9
|
File without changes
|
File without changes
|
File without changes
|