continual-foragax 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.31.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.31.0.dist-info}/RECORD +8 -8
- foragax/env.py +168 -34
- foragax/objects.py +101 -1
- foragax/registry.py +24 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.31.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.31.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.31.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
4
|
-
foragax/objects.py,sha256=
|
|
5
|
-
foragax/registry.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=j7u8Xv8Fz4-GZU0SXXUP40Nz4NCvRrX02h9gl8PZMK4,32456
|
|
4
|
+
foragax/objects.py,sha256=FBO0k4X6tKidWWsk6E-GotA9jXYXc8GXjZxNQJItvWI,14119
|
|
5
|
+
foragax/registry.py,sha256=zF3RtW8ssWvnDKS31xJ_Ac9hP5sv8DWy4xBu4uCbHpQ,16060
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.
|
|
133
|
-
continual_foragax-0.
|
|
134
|
-
continual_foragax-0.
|
|
135
|
-
continual_foragax-0.
|
|
131
|
+
continual_foragax-0.31.0.dist-info/METADATA,sha256=o5l0-uCkTUx4jEV9qc49cOagYlRcksIBUkumbbPde8U,4897
|
|
132
|
+
continual_foragax-0.31.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.31.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.31.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.31.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -62,6 +62,7 @@ class EnvState(environment.EnvState):
|
|
|
62
62
|
biome_grid: jax.Array
|
|
63
63
|
time: int
|
|
64
64
|
digestion_buffer: jax.Array
|
|
65
|
+
object_spawn_time_grid: jax.Array
|
|
65
66
|
|
|
66
67
|
|
|
67
68
|
class ForagaxEnv(environment.Environment):
|
|
@@ -114,6 +115,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
114
115
|
self.reward_fns = [o.reward for o in objects]
|
|
115
116
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
|
116
117
|
self.reward_delay_fns = [o.reward_delay for o in objects]
|
|
118
|
+
self.expiry_regen_delay_fns = [o.expiry_regen_delay for o in objects]
|
|
119
|
+
|
|
120
|
+
# Expiry times per object (None becomes -1 for no expiry)
|
|
121
|
+
self.object_expiry_time = jnp.array(
|
|
122
|
+
[o.expiry_time if o.expiry_time is not None else -1 for o in objects]
|
|
123
|
+
)
|
|
117
124
|
|
|
118
125
|
# Compute reward steps per object (using max_reward_delay attribute)
|
|
119
126
|
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
|
@@ -193,6 +200,47 @@ class ForagaxEnv(environment.Environment):
|
|
|
193
200
|
max_steps_in_episode=None,
|
|
194
201
|
)
|
|
195
202
|
|
|
203
|
+
def _place_timer_at_position(
|
|
204
|
+
self, grid: jax.Array, y: int, x: int, timer_val: int
|
|
205
|
+
) -> jax.Array:
|
|
206
|
+
"""Place a timer at a specific position."""
|
|
207
|
+
return grid.at[y, x].set(timer_val)
|
|
208
|
+
|
|
209
|
+
def _place_timer_at_random_position(
|
|
210
|
+
self,
|
|
211
|
+
grid: jax.Array,
|
|
212
|
+
y: int,
|
|
213
|
+
x: int,
|
|
214
|
+
timer_val: int,
|
|
215
|
+
biome_grid: jax.Array,
|
|
216
|
+
rand_key: jax.Array,
|
|
217
|
+
) -> jax.Array:
|
|
218
|
+
"""Place a timer at a random valid position within the same biome."""
|
|
219
|
+
# Set the original position to empty temporarily
|
|
220
|
+
grid_temp = grid.at[y, x].set(0)
|
|
221
|
+
|
|
222
|
+
# Find all valid spawn locations (empty cells within the same biome)
|
|
223
|
+
biome_id = biome_grid[y, x]
|
|
224
|
+
biome_mask = biome_grid == biome_id
|
|
225
|
+
empty_mask = grid_temp == 0
|
|
226
|
+
valid_spawn_mask = biome_mask & empty_mask
|
|
227
|
+
|
|
228
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
|
229
|
+
|
|
230
|
+
# Get indices of valid spawn locations, padded to a static size
|
|
231
|
+
y_indices, x_indices = jnp.nonzero(
|
|
232
|
+
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
|
233
|
+
)
|
|
234
|
+
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
235
|
+
|
|
236
|
+
# Select a random valid location
|
|
237
|
+
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
|
238
|
+
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
239
|
+
|
|
240
|
+
# Place the timer at the new random position
|
|
241
|
+
new_grid = grid_temp.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
|
|
242
|
+
return new_grid
|
|
243
|
+
|
|
196
244
|
def step_env(
|
|
197
245
|
self,
|
|
198
246
|
key: jax.Array,
|
|
@@ -277,54 +325,129 @@ class ForagaxEnv(environment.Environment):
|
|
|
277
325
|
is_timer, state.object_grid + num_obj_types, state.object_grid
|
|
278
326
|
)
|
|
279
327
|
|
|
328
|
+
# Track which objects just respawned (timer -> object transition)
|
|
329
|
+
# An object respawned if it was negative and is now positive
|
|
330
|
+
just_respawned = is_timer & (object_grid > 0)
|
|
331
|
+
|
|
332
|
+
# Update spawn times for objects that just respawned
|
|
333
|
+
object_spawn_time_grid = jnp.where(
|
|
334
|
+
just_respawned, state.time, state.object_spawn_time_grid
|
|
335
|
+
)
|
|
336
|
+
|
|
280
337
|
# Collect object: set a timer
|
|
281
338
|
regen_delay = jax.lax.switch(
|
|
282
339
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
283
340
|
)
|
|
284
341
|
encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
|
|
285
342
|
|
|
286
|
-
def place_at_current_pos(current_grid, timer_val):
|
|
287
|
-
return current_grid.at[pos[1], pos[0]].set(timer_val)
|
|
288
|
-
|
|
289
|
-
def place_at_random_pos(current_grid, timer_val):
|
|
290
|
-
# Set the collected position to empty temporarily
|
|
291
|
-
grid = current_grid.at[pos[1], pos[0]].set(0)
|
|
292
|
-
|
|
293
|
-
# Find all valid spawn locations (empty cells within the same biome)
|
|
294
|
-
biome_id = state.biome_grid[pos[1], pos[0]]
|
|
295
|
-
biome_mask = state.biome_grid == biome_id
|
|
296
|
-
empty_mask = grid == 0
|
|
297
|
-
valid_spawn_mask = biome_mask & empty_mask
|
|
298
|
-
|
|
299
|
-
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
|
300
|
-
|
|
301
|
-
# Get indices of valid spawn locations, padded to a static size
|
|
302
|
-
y_indices, x_indices = jnp.nonzero(
|
|
303
|
-
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
|
304
|
-
)
|
|
305
|
-
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
306
|
-
|
|
307
|
-
# Select a random valid location
|
|
308
|
-
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
|
309
|
-
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
310
|
-
|
|
311
|
-
# Place the timer at the new random position
|
|
312
|
-
return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
|
|
313
|
-
|
|
314
343
|
# If collected, replace object with timer; otherwise, keep it
|
|
315
344
|
val_at_pos = object_grid[pos[1], pos[0]]
|
|
316
345
|
should_collect = is_collectable & (val_at_pos > 0)
|
|
317
346
|
|
|
318
347
|
# When not collecting, the value at the position remains unchanged.
|
|
319
348
|
# When collecting, we either place the timer at the current position or a random one.
|
|
349
|
+
def do_collection():
|
|
350
|
+
return jax.lax.cond(
|
|
351
|
+
self.object_random_respawn[obj_at_pos],
|
|
352
|
+
lambda: self._place_timer_at_random_position(
|
|
353
|
+
object_grid,
|
|
354
|
+
pos[1],
|
|
355
|
+
pos[0],
|
|
356
|
+
encoded_timer,
|
|
357
|
+
state.biome_grid,
|
|
358
|
+
rand_key,
|
|
359
|
+
),
|
|
360
|
+
lambda: self._place_timer_at_position(
|
|
361
|
+
object_grid, pos[1], pos[0], encoded_timer
|
|
362
|
+
),
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
def no_collection():
|
|
366
|
+
return object_grid
|
|
367
|
+
|
|
320
368
|
object_grid = jax.lax.cond(
|
|
321
369
|
should_collect,
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
370
|
+
do_collection,
|
|
371
|
+
no_collection,
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
# 3.5. HANDLE OBJECT EXPIRY
|
|
375
|
+
# Check each cell for objects that have exceeded their expiry time
|
|
376
|
+
current_objects_for_expiry = jnp.maximum(0, object_grid)
|
|
377
|
+
|
|
378
|
+
# Calculate age of each object (current_time - spawn_time)
|
|
379
|
+
object_ages = state.time - object_spawn_time_grid
|
|
380
|
+
|
|
381
|
+
# Get expiry time for each object type in the grid
|
|
382
|
+
expiry_times = self.object_expiry_time[current_objects_for_expiry]
|
|
383
|
+
|
|
384
|
+
# Check if object should expire (age >= expiry_time and expiry_time >= 0)
|
|
385
|
+
should_expire = (
|
|
386
|
+
(object_ages >= expiry_times)
|
|
387
|
+
& (expiry_times >= 0)
|
|
388
|
+
& (current_objects_for_expiry > 0)
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
# For expired objects, calculate expiry regen delay
|
|
392
|
+
key, expiry_key = jax.random.split(key)
|
|
393
|
+
|
|
394
|
+
# Process expiry for all positions that need it
|
|
395
|
+
def process_expiry(y, x, grid, spawn_grid, key):
|
|
396
|
+
obj_id = current_objects_for_expiry[y, x]
|
|
397
|
+
should_exp = should_expire[y, x]
|
|
398
|
+
|
|
399
|
+
def expire_object():
|
|
400
|
+
# Get expiry regen delay for this object
|
|
401
|
+
exp_key = jax.random.fold_in(key, y * self.size[0] + x)
|
|
402
|
+
exp_delay = jax.lax.switch(
|
|
403
|
+
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
404
|
+
)
|
|
405
|
+
encoded_exp_timer = obj_id - ((exp_delay + 1) * num_obj_types)
|
|
406
|
+
|
|
407
|
+
# Check if this object should respawn randomly
|
|
408
|
+
should_random_respawn = self.object_random_respawn[obj_id]
|
|
409
|
+
|
|
410
|
+
# Use second split for randomness in random placement
|
|
411
|
+
rand_key = jax.random.split(exp_key)[1]
|
|
412
|
+
|
|
413
|
+
# Place timer either at current position or random position
|
|
414
|
+
new_grid = jax.lax.cond(
|
|
415
|
+
should_random_respawn,
|
|
416
|
+
lambda: self._place_timer_at_random_position(
|
|
417
|
+
grid, y, x, encoded_exp_timer, state.biome_grid, rand_key
|
|
418
|
+
),
|
|
419
|
+
lambda: self._place_timer_at_position(
|
|
420
|
+
grid, y, x, encoded_exp_timer
|
|
421
|
+
),
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
return new_grid, spawn_grid
|
|
425
|
+
|
|
426
|
+
def no_expire():
|
|
427
|
+
return grid, spawn_grid
|
|
428
|
+
|
|
429
|
+
return jax.lax.cond(should_exp, expire_object, no_expire)
|
|
430
|
+
|
|
431
|
+
# Apply expiry to all cells (vectorized)
|
|
432
|
+
def scan_expiry_row(carry, y):
|
|
433
|
+
grid, spawn_grid, key = carry
|
|
434
|
+
|
|
435
|
+
def scan_expiry_col(carry_col, x):
|
|
436
|
+
grid_col, spawn_grid_col, key_col = carry_col
|
|
437
|
+
grid_col, spawn_grid_col = process_expiry(
|
|
438
|
+
y, x, grid_col, spawn_grid_col, key_col
|
|
439
|
+
)
|
|
440
|
+
return (grid_col, spawn_grid_col, key_col), None
|
|
441
|
+
|
|
442
|
+
(grid, spawn_grid, key), _ = jax.lax.scan(
|
|
443
|
+
scan_expiry_col, (grid, spawn_grid, key), jnp.arange(self.size[0])
|
|
444
|
+
)
|
|
445
|
+
return (grid, spawn_grid, key), None
|
|
446
|
+
|
|
447
|
+
(object_grid, object_spawn_time_grid, _), _ = jax.lax.scan(
|
|
448
|
+
scan_expiry_row,
|
|
449
|
+
(object_grid, object_spawn_time_grid, expiry_key),
|
|
450
|
+
jnp.arange(self.size[1]),
|
|
328
451
|
)
|
|
329
452
|
|
|
330
453
|
info = {"discount": self.discount(state, params)}
|
|
@@ -345,6 +468,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
345
468
|
biome_grid=state.biome_grid,
|
|
346
469
|
time=state.time + 1,
|
|
347
470
|
digestion_buffer=digestion_buffer,
|
|
471
|
+
object_spawn_time_grid=object_spawn_time_grid,
|
|
348
472
|
)
|
|
349
473
|
|
|
350
474
|
done = self.is_terminal(state, params)
|
|
@@ -378,12 +502,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
378
502
|
# Place agent in the center of the world
|
|
379
503
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
|
380
504
|
|
|
505
|
+
# Initialize spawn times to 0 (all objects spawn at time 0)
|
|
506
|
+
object_spawn_time_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
507
|
+
|
|
381
508
|
state = EnvState(
|
|
382
509
|
pos=agent_pos,
|
|
383
510
|
object_grid=object_grid,
|
|
384
511
|
biome_grid=biome_grid,
|
|
385
512
|
time=0,
|
|
386
513
|
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
|
514
|
+
object_spawn_time_grid=object_spawn_time_grid,
|
|
387
515
|
)
|
|
388
516
|
|
|
389
517
|
return self.get_obs(state, params), state
|
|
@@ -450,6 +578,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
450
578
|
(self.max_reward_delay,),
|
|
451
579
|
float,
|
|
452
580
|
),
|
|
581
|
+
"object_spawn_time_grid": spaces.Box(
|
|
582
|
+
0,
|
|
583
|
+
jnp.inf,
|
|
584
|
+
(self.size[1], self.size[0]),
|
|
585
|
+
int,
|
|
586
|
+
),
|
|
453
587
|
}
|
|
454
588
|
)
|
|
455
589
|
|
foragax/objects.py
CHANGED
|
@@ -18,6 +18,7 @@ class BaseForagaxObject:
|
|
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
|
19
19
|
random_respawn: bool = False,
|
|
20
20
|
max_reward_delay: int = 0,
|
|
21
|
+
expiry_time: Optional[int] = None,
|
|
21
22
|
):
|
|
22
23
|
self.name = name
|
|
23
24
|
self.blocking = blocking
|
|
@@ -25,6 +26,7 @@ class BaseForagaxObject:
|
|
|
25
26
|
self.color = color
|
|
26
27
|
self.random_respawn = random_respawn
|
|
27
28
|
self.max_reward_delay = max_reward_delay
|
|
29
|
+
self.expiry_time = expiry_time
|
|
28
30
|
|
|
29
31
|
@abc.abstractmethod
|
|
30
32
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
|
@@ -36,6 +38,11 @@ class BaseForagaxObject:
|
|
|
36
38
|
"""Reward delay function."""
|
|
37
39
|
raise NotImplementedError
|
|
38
40
|
|
|
41
|
+
@abc.abstractmethod
|
|
42
|
+
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
|
|
43
|
+
"""Expiry regeneration delay function."""
|
|
44
|
+
raise NotImplementedError
|
|
45
|
+
|
|
39
46
|
|
|
40
47
|
class DefaultForagaxObject(BaseForagaxObject):
|
|
41
48
|
"""Base class for default objects in the Foragax environment."""
|
|
@@ -51,15 +58,24 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
|
51
58
|
random_respawn: bool = False,
|
|
52
59
|
reward_delay: int = 0,
|
|
53
60
|
max_reward_delay: Optional[int] = None,
|
|
61
|
+
expiry_time: Optional[int] = None,
|
|
62
|
+
expiry_regen_delay: Tuple[int, int] = (10, 100),
|
|
54
63
|
):
|
|
55
64
|
if max_reward_delay is None:
|
|
56
65
|
max_reward_delay = reward_delay
|
|
57
66
|
super().__init__(
|
|
58
|
-
name,
|
|
67
|
+
name,
|
|
68
|
+
blocking,
|
|
69
|
+
collectable,
|
|
70
|
+
color,
|
|
71
|
+
random_respawn,
|
|
72
|
+
max_reward_delay,
|
|
73
|
+
expiry_time,
|
|
59
74
|
)
|
|
60
75
|
self.reward_val = reward
|
|
61
76
|
self.regen_delay_range = regen_delay
|
|
62
77
|
self.reward_delay_val = reward_delay
|
|
78
|
+
self.expiry_regen_delay_range = expiry_regen_delay
|
|
63
79
|
|
|
64
80
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
|
65
81
|
"""Default reward function."""
|
|
@@ -74,6 +90,11 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
|
74
90
|
"""Default reward delay function."""
|
|
75
91
|
return self.reward_delay_val
|
|
76
92
|
|
|
93
|
+
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
|
|
94
|
+
"""Default expiry regeneration delay function."""
|
|
95
|
+
min_delay, max_delay = self.expiry_regen_delay_range
|
|
96
|
+
return jax.random.randint(rng, (), min_delay, max_delay)
|
|
97
|
+
|
|
77
98
|
|
|
78
99
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
79
100
|
"""Object with regeneration delay from a normal distribution."""
|
|
@@ -89,7 +110,16 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
|
89
110
|
random_respawn: bool = False,
|
|
90
111
|
reward_delay: int = 0,
|
|
91
112
|
max_reward_delay: Optional[int] = None,
|
|
113
|
+
expiry_time: Optional[int] = None,
|
|
114
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
115
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
92
116
|
):
|
|
117
|
+
# If expiry regen delays not provided, use same as normal regen
|
|
118
|
+
if mean_expiry_regen_delay is None:
|
|
119
|
+
mean_expiry_regen_delay = mean_regen_delay
|
|
120
|
+
if std_expiry_regen_delay is None:
|
|
121
|
+
std_expiry_regen_delay = std_regen_delay
|
|
122
|
+
|
|
93
123
|
super().__init__(
|
|
94
124
|
name=name,
|
|
95
125
|
reward=reward,
|
|
@@ -99,15 +129,27 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
|
99
129
|
random_respawn=random_respawn,
|
|
100
130
|
reward_delay=reward_delay,
|
|
101
131
|
max_reward_delay=max_reward_delay,
|
|
132
|
+
expiry_time=expiry_time,
|
|
133
|
+
expiry_regen_delay=(mean_expiry_regen_delay, mean_expiry_regen_delay),
|
|
102
134
|
)
|
|
103
135
|
self.mean_regen_delay = mean_regen_delay
|
|
104
136
|
self.std_regen_delay = std_regen_delay
|
|
137
|
+
self.mean_expiry_regen_delay = mean_expiry_regen_delay
|
|
138
|
+
self.std_expiry_regen_delay = std_expiry_regen_delay
|
|
105
139
|
|
|
106
140
|
def regen_delay(self, clock: int, rng: jax.Array) -> int:
|
|
107
141
|
"""Regeneration delay from a normal distribution."""
|
|
108
142
|
delay = self.mean_regen_delay + jax.random.normal(rng) * self.std_regen_delay
|
|
109
143
|
return jnp.maximum(0, delay).astype(int)
|
|
110
144
|
|
|
145
|
+
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
|
|
146
|
+
"""Expiry regeneration delay from a normal distribution."""
|
|
147
|
+
delay = (
|
|
148
|
+
self.mean_expiry_regen_delay
|
|
149
|
+
+ jax.random.normal(rng) * self.std_expiry_regen_delay
|
|
150
|
+
)
|
|
151
|
+
return jnp.maximum(0, delay).astype(int)
|
|
152
|
+
|
|
111
153
|
|
|
112
154
|
class WeatherObject(NormalRegenForagaxObject):
|
|
113
155
|
"""Object with reward based on temperature data."""
|
|
@@ -124,6 +166,9 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
|
124
166
|
random_respawn: bool = False,
|
|
125
167
|
reward_delay: int = 0,
|
|
126
168
|
max_reward_delay: Optional[int] = None,
|
|
169
|
+
expiry_time: Optional[int] = None,
|
|
170
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
171
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
127
172
|
):
|
|
128
173
|
super().__init__(
|
|
129
174
|
name=name,
|
|
@@ -134,6 +179,9 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
|
134
179
|
random_respawn=random_respawn,
|
|
135
180
|
reward_delay=reward_delay,
|
|
136
181
|
max_reward_delay=max_reward_delay,
|
|
182
|
+
expiry_time=expiry_time,
|
|
183
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
184
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
137
185
|
)
|
|
138
186
|
self.rewards = rewards * multiplier
|
|
139
187
|
self.repeat = repeat
|
|
@@ -332,6 +380,44 @@ GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
|
|
|
332
380
|
random_respawn=True,
|
|
333
381
|
)
|
|
334
382
|
|
|
383
|
+
# Random respawn variants with expiry
|
|
384
|
+
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
|
|
385
|
+
name="brown_morel",
|
|
386
|
+
reward=10.0,
|
|
387
|
+
collectable=True,
|
|
388
|
+
color=(63, 30, 25),
|
|
389
|
+
regen_delay=(90, 110),
|
|
390
|
+
random_respawn=True,
|
|
391
|
+
expiry_time=500,
|
|
392
|
+
)
|
|
393
|
+
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
|
|
394
|
+
name="brown_oyster",
|
|
395
|
+
reward=1.0,
|
|
396
|
+
collectable=True,
|
|
397
|
+
color=(63, 30, 25),
|
|
398
|
+
regen_delay=(9, 11),
|
|
399
|
+
random_respawn=True,
|
|
400
|
+
expiry_time=500,
|
|
401
|
+
)
|
|
402
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
|
|
403
|
+
name="green_deathcap",
|
|
404
|
+
reward=-5.0,
|
|
405
|
+
collectable=True,
|
|
406
|
+
color=(0, 255, 0),
|
|
407
|
+
regen_delay=(9, 11),
|
|
408
|
+
random_respawn=True,
|
|
409
|
+
expiry_time=500,
|
|
410
|
+
)
|
|
411
|
+
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
|
|
412
|
+
name="green_fake",
|
|
413
|
+
reward=0.0,
|
|
414
|
+
collectable=True,
|
|
415
|
+
color=(0, 255, 0),
|
|
416
|
+
regen_delay=(9, 11),
|
|
417
|
+
random_respawn=True,
|
|
418
|
+
expiry_time=500,
|
|
419
|
+
)
|
|
420
|
+
|
|
335
421
|
|
|
336
422
|
def create_weather_objects(
|
|
337
423
|
file_index: int = 0,
|
|
@@ -340,6 +426,9 @@ def create_weather_objects(
|
|
|
340
426
|
same_color: bool = False,
|
|
341
427
|
random_respawn: bool = False,
|
|
342
428
|
reward_delay: int = 0,
|
|
429
|
+
expiry_time: Optional[int] = None,
|
|
430
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
431
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
343
432
|
):
|
|
344
433
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
|
345
434
|
|
|
@@ -348,6 +437,11 @@ def create_weather_objects(
|
|
|
348
437
|
repeat: How many steps each temperature value repeats for.
|
|
349
438
|
multiplier: Base multiplier applied to HOT; COLD will use -multiplier.
|
|
350
439
|
same_color: If True, both HOT and COLD use the same color.
|
|
440
|
+
random_respawn: If True, objects respawn at random locations.
|
|
441
|
+
reward_delay: Number of steps before reward is delivered.
|
|
442
|
+
expiry_time: Time steps before object expires (None = no expiry).
|
|
443
|
+
mean_expiry_regen_delay: Mean delay for expiry respawn.
|
|
444
|
+
std_expiry_regen_delay: Standard deviation for expiry respawn delay.
|
|
351
445
|
|
|
352
446
|
Returns:
|
|
353
447
|
A tuple (HOT, COLD) of WeatherObject instances.
|
|
@@ -370,6 +464,9 @@ def create_weather_objects(
|
|
|
370
464
|
color=hot_color,
|
|
371
465
|
random_respawn=random_respawn,
|
|
372
466
|
reward_delay=reward_delay,
|
|
467
|
+
expiry_time=expiry_time,
|
|
468
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
469
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
373
470
|
)
|
|
374
471
|
|
|
375
472
|
cold_color = hot_color if same_color else (0, 255, 255)
|
|
@@ -381,6 +478,9 @@ def create_weather_objects(
|
|
|
381
478
|
color=cold_color,
|
|
382
479
|
random_respawn=random_respawn,
|
|
383
480
|
reward_delay=reward_delay,
|
|
481
|
+
expiry_time=expiry_time,
|
|
482
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
483
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
384
484
|
)
|
|
385
485
|
|
|
386
486
|
return hot, cold
|
foragax/registry.py
CHANGED
|
@@ -12,18 +12,22 @@ from foragax.objects import (
|
|
|
12
12
|
BROWN_MOREL_2,
|
|
13
13
|
BROWN_MOREL_UNIFORM,
|
|
14
14
|
BROWN_MOREL_UNIFORM_RANDOM,
|
|
15
|
+
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
|
|
15
16
|
BROWN_OYSTER,
|
|
16
17
|
BROWN_OYSTER_UNIFORM,
|
|
17
18
|
BROWN_OYSTER_UNIFORM_RANDOM,
|
|
19
|
+
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
|
|
18
20
|
GREEN_DEATHCAP,
|
|
19
21
|
GREEN_DEATHCAP_2,
|
|
20
22
|
GREEN_DEATHCAP_3,
|
|
21
23
|
GREEN_DEATHCAP_UNIFORM,
|
|
22
24
|
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
|
25
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
|
|
23
26
|
GREEN_FAKE,
|
|
24
27
|
GREEN_FAKE_2,
|
|
25
28
|
GREEN_FAKE_UNIFORM,
|
|
26
29
|
GREEN_FAKE_UNIFORM_RANDOM,
|
|
30
|
+
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
|
|
27
31
|
LARGE_MOREL,
|
|
28
32
|
LARGE_OYSTER,
|
|
29
33
|
MEDIUM_MOREL,
|
|
@@ -304,6 +308,26 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
304
308
|
"nowrap": True,
|
|
305
309
|
"deterministic_spawn": True,
|
|
306
310
|
},
|
|
311
|
+
"ForagaxTwoBiome-v17": {
|
|
312
|
+
"size": (15, 15),
|
|
313
|
+
"aperture_size": None,
|
|
314
|
+
"objects": (
|
|
315
|
+
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
|
|
316
|
+
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
|
|
317
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
|
|
318
|
+
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
|
|
319
|
+
),
|
|
320
|
+
"biomes": (
|
|
321
|
+
# Morel biome
|
|
322
|
+
Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
|
|
323
|
+
# Oyster biome
|
|
324
|
+
Biome(
|
|
325
|
+
start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)
|
|
326
|
+
),
|
|
327
|
+
),
|
|
328
|
+
"nowrap": False,
|
|
329
|
+
"deterministic_spawn": True,
|
|
330
|
+
},
|
|
307
331
|
"ForagaxTwoBiomeSmall-v1": {
|
|
308
332
|
"size": (16, 8),
|
|
309
333
|
"aperture_size": None,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|