continual-foragax 0.29.0__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.0.dist-info}/RECORD +8 -8
- foragax/env.py +14 -14
- foragax/objects.py +24 -24
- foragax/registry.py +6 -17
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
|
4
|
+
foragax/objects.py,sha256=0vb_iyr62BKaIxiE3JwtRhZhFE3VFM6PdxDZTaDtv24,10410
|
5
|
+
foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.30.0.dist-info/METADATA,sha256=d0xeSz0BvDVe1lOUGdhVyqnbkkYN7dNW4BPfCnDSZfQ,4897
|
132
|
+
continual_foragax-0.30.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.30.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.30.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.30.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -113,12 +113,12 @@ class ForagaxEnv(environment.Environment):
|
|
113
113
|
|
114
114
|
self.reward_fns = [o.reward for o in objects]
|
115
115
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
116
|
-
self.
|
116
|
+
self.reward_delay_fns = [o.reward_delay for o in objects]
|
117
117
|
|
118
|
-
# Compute
|
119
|
-
|
120
|
-
self.
|
121
|
-
int(jnp.max(
|
118
|
+
# Compute reward steps per object (using max_reward_delay attribute)
|
119
|
+
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
120
|
+
self.max_reward_delay = (
|
121
|
+
int(jnp.max(object_max_reward_delay)) + 1 if len(objects) > 0 else 0
|
122
122
|
)
|
123
123
|
|
124
124
|
self.biome_object_frequencies = jnp.array(
|
@@ -250,21 +250,21 @@ class ForagaxEnv(environment.Environment):
|
|
250
250
|
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
251
251
|
)
|
252
252
|
key, digestion_subkey = jax.random.split(key)
|
253
|
-
|
254
|
-
obj_at_pos, self.
|
253
|
+
reward_delay = jax.lax.switch(
|
254
|
+
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
255
255
|
)
|
256
|
-
reward = jnp.where(should_collect & (
|
257
|
-
if self.
|
256
|
+
reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
|
257
|
+
if self.max_reward_delay > 0:
|
258
258
|
# Add delayed rewards to buffer
|
259
259
|
digestion_buffer = jax.lax.cond(
|
260
|
-
should_collect & (
|
260
|
+
should_collect & (reward_delay > 0),
|
261
261
|
lambda: digestion_buffer.at[
|
262
|
-
(state.time +
|
262
|
+
(state.time + reward_delay) % self.max_reward_delay
|
263
263
|
].add(object_reward),
|
264
264
|
lambda: digestion_buffer,
|
265
265
|
)
|
266
266
|
# Deliver current rewards
|
267
|
-
current_index = state.time % self.
|
267
|
+
current_index = state.time % self.max_reward_delay
|
268
268
|
reward += digestion_buffer[current_index]
|
269
269
|
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
270
270
|
|
@@ -383,7 +383,7 @@ class ForagaxEnv(environment.Environment):
|
|
383
383
|
object_grid=object_grid,
|
384
384
|
biome_grid=biome_grid,
|
385
385
|
time=0,
|
386
|
-
digestion_buffer=jnp.zeros((self.
|
386
|
+
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
387
387
|
)
|
388
388
|
|
389
389
|
return self.get_obs(state, params), state
|
@@ -447,7 +447,7 @@ class ForagaxEnv(environment.Environment):
|
|
447
447
|
"digestion_buffer": spaces.Box(
|
448
448
|
-jnp.inf,
|
449
449
|
jnp.inf,
|
450
|
-
(self.
|
450
|
+
(self.max_reward_delay,),
|
451
451
|
float,
|
452
452
|
),
|
453
453
|
}
|
foragax/objects.py
CHANGED
@@ -17,14 +17,14 @@ class BaseForagaxObject:
|
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
19
|
random_respawn: bool = False,
|
20
|
-
|
20
|
+
max_reward_delay: int = 0,
|
21
21
|
):
|
22
22
|
self.name = name
|
23
23
|
self.blocking = blocking
|
24
24
|
self.collectable = collectable
|
25
25
|
self.color = color
|
26
26
|
self.random_respawn = random_respawn
|
27
|
-
self.
|
27
|
+
self.max_reward_delay = max_reward_delay
|
28
28
|
|
29
29
|
@abc.abstractmethod
|
30
30
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -32,8 +32,8 @@ class BaseForagaxObject:
|
|
32
32
|
raise NotImplementedError
|
33
33
|
|
34
34
|
@abc.abstractmethod
|
35
|
-
def
|
36
|
-
"""
|
35
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
36
|
+
"""Reward delay function."""
|
37
37
|
raise NotImplementedError
|
38
38
|
|
39
39
|
|
@@ -49,17 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
49
49
|
regen_delay: Tuple[int, int] = (10, 100),
|
50
50
|
color: Tuple[int, int, int] = (255, 255, 255),
|
51
51
|
random_respawn: bool = False,
|
52
|
-
|
53
|
-
|
52
|
+
reward_delay: int = 0,
|
53
|
+
max_reward_delay: Optional[int] = None,
|
54
54
|
):
|
55
|
-
if
|
56
|
-
|
55
|
+
if max_reward_delay is None:
|
56
|
+
max_reward_delay = reward_delay
|
57
57
|
super().__init__(
|
58
|
-
name, blocking, collectable, color, random_respawn,
|
58
|
+
name, blocking, collectable, color, random_respawn, max_reward_delay
|
59
59
|
)
|
60
60
|
self.reward_val = reward
|
61
61
|
self.regen_delay_range = regen_delay
|
62
|
-
self.
|
62
|
+
self.reward_delay_val = reward_delay
|
63
63
|
|
64
64
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
65
65
|
"""Default reward function."""
|
@@ -70,9 +70,9 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
70
70
|
min_delay, max_delay = self.regen_delay_range
|
71
71
|
return jax.random.randint(rng, (), min_delay, max_delay)
|
72
72
|
|
73
|
-
def
|
74
|
-
"""Default
|
75
|
-
return self.
|
73
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
74
|
+
"""Default reward delay function."""
|
75
|
+
return self.reward_delay_val
|
76
76
|
|
77
77
|
|
78
78
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
@@ -87,8 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
87
87
|
std_regen_delay: int = 1,
|
88
88
|
color: Tuple[int, int, int] = (0, 0, 0),
|
89
89
|
random_respawn: bool = False,
|
90
|
-
|
91
|
-
|
90
|
+
reward_delay: int = 0,
|
91
|
+
max_reward_delay: Optional[int] = None,
|
92
92
|
):
|
93
93
|
super().__init__(
|
94
94
|
name=name,
|
@@ -97,8 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
97
97
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
98
98
|
color=color,
|
99
99
|
random_respawn=random_respawn,
|
100
|
-
|
101
|
-
|
100
|
+
reward_delay=reward_delay,
|
101
|
+
max_reward_delay=max_reward_delay,
|
102
102
|
)
|
103
103
|
self.mean_regen_delay = mean_regen_delay
|
104
104
|
self.std_regen_delay = std_regen_delay
|
@@ -122,8 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
122
122
|
std_regen_delay: int = 1,
|
123
123
|
color: Tuple[int, int, int] = (0, 0, 0),
|
124
124
|
random_respawn: bool = False,
|
125
|
-
|
126
|
-
|
125
|
+
reward_delay: int = 0,
|
126
|
+
max_reward_delay: Optional[int] = None,
|
127
127
|
):
|
128
128
|
super().__init__(
|
129
129
|
name=name,
|
@@ -132,8 +132,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
132
132
|
std_regen_delay=std_regen_delay,
|
133
133
|
color=color,
|
134
134
|
random_respawn=random_respawn,
|
135
|
-
|
136
|
-
|
135
|
+
reward_delay=reward_delay,
|
136
|
+
max_reward_delay=max_reward_delay,
|
137
137
|
)
|
138
138
|
self.rewards = rewards
|
139
139
|
self.repeat = repeat
|
@@ -340,7 +340,7 @@ def create_weather_objects(
|
|
340
340
|
multiplier: float = 1.0,
|
341
341
|
same_color: bool = False,
|
342
342
|
random_respawn: bool = False,
|
343
|
-
|
343
|
+
reward_delay: int = 0,
|
344
344
|
):
|
345
345
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
346
346
|
|
@@ -370,7 +370,7 @@ def create_weather_objects(
|
|
370
370
|
multiplier=multiplier,
|
371
371
|
color=hot_color,
|
372
372
|
random_respawn=random_respawn,
|
373
|
-
|
373
|
+
reward_delay=reward_delay,
|
374
374
|
)
|
375
375
|
|
376
376
|
cold_color = hot_color if same_color else (0, 255, 255)
|
@@ -381,7 +381,7 @@ def create_weather_objects(
|
|
381
381
|
multiplier=-multiplier,
|
382
382
|
color=cold_color,
|
383
383
|
random_respawn=random_respawn,
|
384
|
-
|
384
|
+
reward_delay=reward_delay,
|
385
385
|
)
|
386
386
|
|
387
387
|
return hot, cold
|
foragax/registry.py
CHANGED
@@ -83,19 +83,6 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
83
83
|
"nowrap": False,
|
84
84
|
"deterministic_spawn": True,
|
85
85
|
},
|
86
|
-
"ForagaxWeather-v6": {
|
87
|
-
"size": (15, 15),
|
88
|
-
"aperture_size": None,
|
89
|
-
"objects": None,
|
90
|
-
"biomes": (
|
91
|
-
# Hot biome
|
92
|
-
Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
|
93
|
-
# Cold biome
|
94
|
-
Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
|
95
|
-
),
|
96
|
-
"nowrap": False,
|
97
|
-
"deterministic_spawn": True,
|
98
|
-
},
|
99
86
|
"ForagaxTwoBiome-v1": {
|
100
87
|
"size": (15, 15),
|
101
88
|
"aperture_size": None,
|
@@ -361,6 +348,8 @@ def make(
|
|
361
348
|
observation_type: str = "color",
|
362
349
|
aperture_size: Optional[Tuple[int, int]] = (5, 5),
|
363
350
|
file_index: int = 0,
|
351
|
+
repeat: int = 500,
|
352
|
+
reward_delay: int = 0,
|
364
353
|
**kwargs: Any,
|
365
354
|
) -> ForagaxEnv:
|
366
355
|
"""Create a Foragax environment.
|
@@ -371,6 +360,8 @@ def make(
|
|
371
360
|
aperture_size: The size of the agent's observation aperture. If -1, full world observation.
|
372
361
|
If None, the default for the environment is used.
|
373
362
|
file_index: File index for weather objects.
|
363
|
+
repeat: How many steps each temperature value repeats for (weather environments).
|
364
|
+
reward_delay: Number of steps required to digest food items (weather environments).
|
374
365
|
**kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
|
375
366
|
|
376
367
|
Returns:
|
@@ -467,19 +458,17 @@ def make(
|
|
467
458
|
"ForagaxWeather-v3",
|
468
459
|
"ForagaxWeather-v4",
|
469
460
|
"ForagaxWeather-v5",
|
470
|
-
"ForagaxWeather-v6",
|
471
461
|
)
|
472
462
|
random_respawn = env_id in (
|
473
463
|
"ForagaxWeather-v4",
|
474
464
|
"ForagaxWeather-v5",
|
475
|
-
"ForagaxWeather-v6",
|
476
465
|
)
|
477
|
-
digestion_steps = 10 if env_id in ("ForagaxWeather-v6") else 0
|
478
466
|
hot, cold = create_weather_objects(
|
479
467
|
file_index=file_index,
|
468
|
+
repeat=repeat,
|
480
469
|
same_color=same_color,
|
481
470
|
random_respawn=random_respawn,
|
482
|
-
|
471
|
+
reward_delay=reward_delay,
|
483
472
|
)
|
484
473
|
config["objects"] = (hot, cold)
|
485
474
|
|
File without changes
|
File without changes
|
File without changes
|