continual-foragax 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.28.1.dist-info → continual_foragax-0.30.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.28.1.dist-info → continual_foragax-0.30.0.dist-info}/RECORD +8 -8
- foragax/env.py +51 -13
- foragax/objects.py +28 -4
- foragax/registry.py +14 -8
- {continual_foragax-0.28.1.dist-info → continual_foragax-0.30.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.28.1.dist-info → continual_foragax-0.30.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.28.1.dist-info → continual_foragax-0.30.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
|
4
|
+
foragax/objects.py,sha256=0vb_iyr62BKaIxiE3JwtRhZhFE3VFM6PdxDZTaDtv24,10410
|
5
|
+
foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.30.0.dist-info/METADATA,sha256=d0xeSz0BvDVe1lOUGdhVyqnbkkYN7dNW4BPfCnDSZfQ,4897
|
132
|
+
continual_foragax-0.30.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.30.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.30.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.30.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -61,6 +61,7 @@ class EnvState(environment.EnvState):
|
|
61
61
|
object_grid: jax.Array
|
62
62
|
biome_grid: jax.Array
|
63
63
|
time: int
|
64
|
+
digestion_buffer: jax.Array
|
64
65
|
|
65
66
|
|
66
67
|
class ForagaxEnv(environment.Environment):
|
@@ -102,11 +103,6 @@ class ForagaxEnv(environment.Environment):
|
|
102
103
|
if self.nowrap and not self.full_world:
|
103
104
|
objects = objects + (PADDING,)
|
104
105
|
self.objects = objects
|
105
|
-
self.weather_object = None
|
106
|
-
for o in objects:
|
107
|
-
if isinstance(o, WeatherObject):
|
108
|
-
self.weather_object = o
|
109
|
-
break
|
110
106
|
|
111
107
|
# JIT-compatible versions of object and biome properties
|
112
108
|
self.object_ids = jnp.arange(len(objects))
|
@@ -117,6 +113,13 @@ class ForagaxEnv(environment.Environment):
|
|
117
113
|
|
118
114
|
self.reward_fns = [o.reward for o in objects]
|
119
115
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
116
|
+
self.reward_delay_fns = [o.reward_delay for o in objects]
|
117
|
+
|
118
|
+
# Compute reward steps per object (using max_reward_delay attribute)
|
119
|
+
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
120
|
+
self.max_reward_delay = (
|
121
|
+
int(jnp.max(object_max_reward_delay)) + 1 if len(objects) > 0 else 0
|
122
|
+
)
|
120
123
|
|
121
124
|
self.biome_object_frequencies = jnp.array(
|
122
125
|
[b.object_frequencies for b in biomes]
|
@@ -237,12 +240,36 @@ class ForagaxEnv(environment.Environment):
|
|
237
240
|
|
238
241
|
# 2. HANDLE COLLISIONS AND REWARDS
|
239
242
|
obj_at_pos = current_objects[pos[1], pos[0]]
|
240
|
-
key, subkey = jax.random.split(key)
|
241
|
-
reward = jax.lax.switch(obj_at_pos, self.reward_fns, state.time, subkey)
|
242
243
|
is_collectable = self.object_collectable[obj_at_pos]
|
244
|
+
should_collect = is_collectable & (obj_at_pos > 0)
|
245
|
+
|
246
|
+
# Handle digestion: add reward to buffer if collected
|
247
|
+
digestion_buffer = state.digestion_buffer
|
248
|
+
key, reward_subkey = jax.random.split(key)
|
249
|
+
object_reward = jax.lax.switch(
|
250
|
+
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
251
|
+
)
|
252
|
+
key, digestion_subkey = jax.random.split(key)
|
253
|
+
reward_delay = jax.lax.switch(
|
254
|
+
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
255
|
+
)
|
256
|
+
reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
|
257
|
+
if self.max_reward_delay > 0:
|
258
|
+
# Add delayed rewards to buffer
|
259
|
+
digestion_buffer = jax.lax.cond(
|
260
|
+
should_collect & (reward_delay > 0),
|
261
|
+
lambda: digestion_buffer.at[
|
262
|
+
(state.time + reward_delay) % self.max_reward_delay
|
263
|
+
].add(object_reward),
|
264
|
+
lambda: digestion_buffer,
|
265
|
+
)
|
266
|
+
# Deliver current rewards
|
267
|
+
current_index = state.time % self.max_reward_delay
|
268
|
+
reward += digestion_buffer[current_index]
|
269
|
+
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
243
270
|
|
244
271
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
245
|
-
key,
|
272
|
+
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
246
273
|
|
247
274
|
# Decrement timers (stored as negative values)
|
248
275
|
is_timer = state.object_grid < 0
|
@@ -252,7 +279,7 @@ class ForagaxEnv(environment.Environment):
|
|
252
279
|
|
253
280
|
# Collect object: set a timer
|
254
281
|
regen_delay = jax.lax.switch(
|
255
|
-
obj_at_pos, self.regen_delay_fns, state.time,
|
282
|
+
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
256
283
|
)
|
257
284
|
encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
|
258
285
|
|
@@ -301,10 +328,13 @@ class ForagaxEnv(environment.Environment):
|
|
301
328
|
)
|
302
329
|
|
303
330
|
info = {"discount": self.discount(state, params)}
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
331
|
+
temperatures = jnp.zeros(len(self.objects))
|
332
|
+
for obj_index, obj in enumerate(self.objects):
|
333
|
+
if isinstance(obj, WeatherObject):
|
334
|
+
temperatures = temperatures.at[obj_index].set(
|
335
|
+
get_temperature(obj.rewards, state.time, obj.repeat)
|
336
|
+
)
|
337
|
+
info["temperatures"] = temperatures
|
308
338
|
info["biome_id"] = state.biome_grid[pos[1], pos[0]]
|
309
339
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
310
340
|
|
@@ -314,6 +344,7 @@ class ForagaxEnv(environment.Environment):
|
|
314
344
|
object_grid=object_grid,
|
315
345
|
biome_grid=state.biome_grid,
|
316
346
|
time=state.time + 1,
|
347
|
+
digestion_buffer=digestion_buffer,
|
317
348
|
)
|
318
349
|
|
319
350
|
done = self.is_terminal(state, params)
|
@@ -352,6 +383,7 @@ class ForagaxEnv(environment.Environment):
|
|
352
383
|
object_grid=object_grid,
|
353
384
|
biome_grid=biome_grid,
|
354
385
|
time=0,
|
386
|
+
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
355
387
|
)
|
356
388
|
|
357
389
|
return self.get_obs(state, params), state
|
@@ -412,6 +444,12 @@ class ForagaxEnv(environment.Environment):
|
|
412
444
|
int,
|
413
445
|
),
|
414
446
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
447
|
+
"digestion_buffer": spaces.Box(
|
448
|
+
-jnp.inf,
|
449
|
+
jnp.inf,
|
450
|
+
(self.max_reward_delay,),
|
451
|
+
float,
|
452
|
+
),
|
415
453
|
}
|
416
454
|
)
|
417
455
|
|
foragax/objects.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import abc
|
2
|
-
from typing import Tuple
|
2
|
+
from typing import Optional, Tuple
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
@@ -17,12 +17,14 @@ class BaseForagaxObject:
|
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
19
|
random_respawn: bool = False,
|
20
|
+
max_reward_delay: int = 0,
|
20
21
|
):
|
21
22
|
self.name = name
|
22
23
|
self.blocking = blocking
|
23
24
|
self.collectable = collectable
|
24
25
|
self.color = color
|
25
26
|
self.random_respawn = random_respawn
|
27
|
+
self.max_reward_delay = max_reward_delay
|
26
28
|
|
27
29
|
@abc.abstractmethod
|
28
30
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -30,8 +32,8 @@ class BaseForagaxObject:
|
|
30
32
|
raise NotImplementedError
|
31
33
|
|
32
34
|
@abc.abstractmethod
|
33
|
-
def
|
34
|
-
"""
|
35
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
36
|
+
"""Reward delay function."""
|
35
37
|
raise NotImplementedError
|
36
38
|
|
37
39
|
|
@@ -47,10 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
47
49
|
regen_delay: Tuple[int, int] = (10, 100),
|
48
50
|
color: Tuple[int, int, int] = (255, 255, 255),
|
49
51
|
random_respawn: bool = False,
|
52
|
+
reward_delay: int = 0,
|
53
|
+
max_reward_delay: Optional[int] = None,
|
50
54
|
):
|
51
|
-
|
55
|
+
if max_reward_delay is None:
|
56
|
+
max_reward_delay = reward_delay
|
57
|
+
super().__init__(
|
58
|
+
name, blocking, collectable, color, random_respawn, max_reward_delay
|
59
|
+
)
|
52
60
|
self.reward_val = reward
|
53
61
|
self.regen_delay_range = regen_delay
|
62
|
+
self.reward_delay_val = reward_delay
|
54
63
|
|
55
64
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
56
65
|
"""Default reward function."""
|
@@ -61,6 +70,10 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
61
70
|
min_delay, max_delay = self.regen_delay_range
|
62
71
|
return jax.random.randint(rng, (), min_delay, max_delay)
|
63
72
|
|
73
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
74
|
+
"""Default reward delay function."""
|
75
|
+
return self.reward_delay_val
|
76
|
+
|
64
77
|
|
65
78
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
66
79
|
"""Object with regeneration delay from a normal distribution."""
|
@@ -74,6 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
74
87
|
std_regen_delay: int = 1,
|
75
88
|
color: Tuple[int, int, int] = (0, 0, 0),
|
76
89
|
random_respawn: bool = False,
|
90
|
+
reward_delay: int = 0,
|
91
|
+
max_reward_delay: Optional[int] = None,
|
77
92
|
):
|
78
93
|
super().__init__(
|
79
94
|
name=name,
|
@@ -82,6 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
82
97
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
83
98
|
color=color,
|
84
99
|
random_respawn=random_respawn,
|
100
|
+
reward_delay=reward_delay,
|
101
|
+
max_reward_delay=max_reward_delay,
|
85
102
|
)
|
86
103
|
self.mean_regen_delay = mean_regen_delay
|
87
104
|
self.std_regen_delay = std_regen_delay
|
@@ -105,6 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
105
122
|
std_regen_delay: int = 1,
|
106
123
|
color: Tuple[int, int, int] = (0, 0, 0),
|
107
124
|
random_respawn: bool = False,
|
125
|
+
reward_delay: int = 0,
|
126
|
+
max_reward_delay: Optional[int] = None,
|
108
127
|
):
|
109
128
|
super().__init__(
|
110
129
|
name=name,
|
@@ -113,6 +132,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
113
132
|
std_regen_delay=std_regen_delay,
|
114
133
|
color=color,
|
115
134
|
random_respawn=random_respawn,
|
135
|
+
reward_delay=reward_delay,
|
136
|
+
max_reward_delay=max_reward_delay,
|
116
137
|
)
|
117
138
|
self.rewards = rewards
|
118
139
|
self.repeat = repeat
|
@@ -319,6 +340,7 @@ def create_weather_objects(
|
|
319
340
|
multiplier: float = 1.0,
|
320
341
|
same_color: bool = False,
|
321
342
|
random_respawn: bool = False,
|
343
|
+
reward_delay: int = 0,
|
322
344
|
):
|
323
345
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
324
346
|
|
@@ -348,6 +370,7 @@ def create_weather_objects(
|
|
348
370
|
multiplier=multiplier,
|
349
371
|
color=hot_color,
|
350
372
|
random_respawn=random_respawn,
|
373
|
+
reward_delay=reward_delay,
|
351
374
|
)
|
352
375
|
|
353
376
|
cold_color = hot_color if same_color else (0, 255, 255)
|
@@ -358,6 +381,7 @@ def create_weather_objects(
|
|
358
381
|
multiplier=-multiplier,
|
359
382
|
color=cold_color,
|
360
383
|
random_respawn=random_respawn,
|
384
|
+
reward_delay=reward_delay,
|
361
385
|
)
|
362
386
|
|
363
387
|
return hot, cold
|
foragax/registry.py
CHANGED
@@ -348,7 +348,8 @@ def make(
|
|
348
348
|
observation_type: str = "color",
|
349
349
|
aperture_size: Optional[Tuple[int, int]] = (5, 5),
|
350
350
|
file_index: int = 0,
|
351
|
-
|
351
|
+
repeat: int = 500,
|
352
|
+
reward_delay: int = 0,
|
352
353
|
**kwargs: Any,
|
353
354
|
) -> ForagaxEnv:
|
354
355
|
"""Create a Foragax environment.
|
@@ -358,9 +359,9 @@ def make(
|
|
358
359
|
observation_type: The type of observation to use. One of "object", "rgb", or "color".
|
359
360
|
aperture_size: The size of the agent's observation aperture. If -1, full world observation.
|
360
361
|
If None, the default for the environment is used.
|
361
|
-
file_index: File index for weather objects.
|
362
|
-
|
363
|
-
|
362
|
+
file_index: File index for weather objects.
|
363
|
+
repeat: How many steps each temperature value repeats for (weather environments).
|
364
|
+
reward_delay: Number of steps required to digest food items (weather environments).
|
364
365
|
**kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
|
365
366
|
|
366
367
|
Returns:
|
@@ -376,8 +377,6 @@ def make(
|
|
376
377
|
else:
|
377
378
|
aperture_size = (aperture_size, aperture_size)
|
378
379
|
config["aperture_size"] = aperture_size
|
379
|
-
if nowrap is not None:
|
380
|
-
config["nowrap"] = nowrap
|
381
380
|
|
382
381
|
# Handle special size and biome configurations
|
383
382
|
if env_id in (
|
@@ -460,9 +459,16 @@ def make(
|
|
460
459
|
"ForagaxWeather-v4",
|
461
460
|
"ForagaxWeather-v5",
|
462
461
|
)
|
463
|
-
random_respawn = env_id in (
|
462
|
+
random_respawn = env_id in (
|
463
|
+
"ForagaxWeather-v4",
|
464
|
+
"ForagaxWeather-v5",
|
465
|
+
)
|
464
466
|
hot, cold = create_weather_objects(
|
465
|
-
file_index=file_index,
|
467
|
+
file_index=file_index,
|
468
|
+
repeat=repeat,
|
469
|
+
same_color=same_color,
|
470
|
+
random_respawn=random_respawn,
|
471
|
+
reward_delay=reward_delay,
|
466
472
|
)
|
467
473
|
config["objects"] = (hot, cold)
|
468
474
|
|
File without changes
|
File without changes
|
File without changes
|