continual-foragax 0.29.0__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.1.dist-info}/METADATA +1 -1
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.1.dist-info}/RECORD +8 -8
- foragax/env.py +14 -14
- foragax/objects.py +26 -27
- foragax/registry.py +6 -17
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.1.dist-info}/WHEEL +0 -0
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.1.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.29.0.dist-info → continual_foragax-0.30.1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
|
4
|
+
foragax/objects.py,sha256=M0nECANGfUvvBRMKSS7akGtoO2Suv5eroI-9Aj326sw,10368
|
5
|
+
foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.30.1.dist-info/METADATA,sha256=9iwHDGT1ZbvjL_CNRFQRQsPBPWKTFBetxrJPk_OXKug,4897
|
132
|
+
continual_foragax-0.30.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.30.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.30.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.30.1.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -113,12 +113,12 @@ class ForagaxEnv(environment.Environment):
|
|
113
113
|
|
114
114
|
self.reward_fns = [o.reward for o in objects]
|
115
115
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
116
|
-
self.
|
116
|
+
self.reward_delay_fns = [o.reward_delay for o in objects]
|
117
117
|
|
118
|
-
# Compute
|
119
|
-
|
120
|
-
self.
|
121
|
-
int(jnp.max(
|
118
|
+
# Compute reward steps per object (using max_reward_delay attribute)
|
119
|
+
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
120
|
+
self.max_reward_delay = (
|
121
|
+
int(jnp.max(object_max_reward_delay)) + 1 if len(objects) > 0 else 0
|
122
122
|
)
|
123
123
|
|
124
124
|
self.biome_object_frequencies = jnp.array(
|
@@ -250,21 +250,21 @@ class ForagaxEnv(environment.Environment):
|
|
250
250
|
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
251
251
|
)
|
252
252
|
key, digestion_subkey = jax.random.split(key)
|
253
|
-
|
254
|
-
obj_at_pos, self.
|
253
|
+
reward_delay = jax.lax.switch(
|
254
|
+
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
255
255
|
)
|
256
|
-
reward = jnp.where(should_collect & (
|
257
|
-
if self.
|
256
|
+
reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
|
257
|
+
if self.max_reward_delay > 0:
|
258
258
|
# Add delayed rewards to buffer
|
259
259
|
digestion_buffer = jax.lax.cond(
|
260
|
-
should_collect & (
|
260
|
+
should_collect & (reward_delay > 0),
|
261
261
|
lambda: digestion_buffer.at[
|
262
|
-
(state.time +
|
262
|
+
(state.time + reward_delay) % self.max_reward_delay
|
263
263
|
].add(object_reward),
|
264
264
|
lambda: digestion_buffer,
|
265
265
|
)
|
266
266
|
# Deliver current rewards
|
267
|
-
current_index = state.time % self.
|
267
|
+
current_index = state.time % self.max_reward_delay
|
268
268
|
reward += digestion_buffer[current_index]
|
269
269
|
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
270
270
|
|
@@ -383,7 +383,7 @@ class ForagaxEnv(environment.Environment):
|
|
383
383
|
object_grid=object_grid,
|
384
384
|
biome_grid=biome_grid,
|
385
385
|
time=0,
|
386
|
-
digestion_buffer=jnp.zeros((self.
|
386
|
+
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
387
387
|
)
|
388
388
|
|
389
389
|
return self.get_obs(state, params), state
|
@@ -447,7 +447,7 @@ class ForagaxEnv(environment.Environment):
|
|
447
447
|
"digestion_buffer": spaces.Box(
|
448
448
|
-jnp.inf,
|
449
449
|
jnp.inf,
|
450
|
-
(self.
|
450
|
+
(self.max_reward_delay,),
|
451
451
|
float,
|
452
452
|
),
|
453
453
|
}
|
foragax/objects.py
CHANGED
@@ -17,14 +17,14 @@ class BaseForagaxObject:
|
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
19
|
random_respawn: bool = False,
|
20
|
-
|
20
|
+
max_reward_delay: int = 0,
|
21
21
|
):
|
22
22
|
self.name = name
|
23
23
|
self.blocking = blocking
|
24
24
|
self.collectable = collectable
|
25
25
|
self.color = color
|
26
26
|
self.random_respawn = random_respawn
|
27
|
-
self.
|
27
|
+
self.max_reward_delay = max_reward_delay
|
28
28
|
|
29
29
|
@abc.abstractmethod
|
30
30
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -32,8 +32,8 @@ class BaseForagaxObject:
|
|
32
32
|
raise NotImplementedError
|
33
33
|
|
34
34
|
@abc.abstractmethod
|
35
|
-
def
|
36
|
-
"""
|
35
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
36
|
+
"""Reward delay function."""
|
37
37
|
raise NotImplementedError
|
38
38
|
|
39
39
|
|
@@ -49,17 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
49
49
|
regen_delay: Tuple[int, int] = (10, 100),
|
50
50
|
color: Tuple[int, int, int] = (255, 255, 255),
|
51
51
|
random_respawn: bool = False,
|
52
|
-
|
53
|
-
|
52
|
+
reward_delay: int = 0,
|
53
|
+
max_reward_delay: Optional[int] = None,
|
54
54
|
):
|
55
|
-
if
|
56
|
-
|
55
|
+
if max_reward_delay is None:
|
56
|
+
max_reward_delay = reward_delay
|
57
57
|
super().__init__(
|
58
|
-
name, blocking, collectable, color, random_respawn,
|
58
|
+
name, blocking, collectable, color, random_respawn, max_reward_delay
|
59
59
|
)
|
60
60
|
self.reward_val = reward
|
61
61
|
self.regen_delay_range = regen_delay
|
62
|
-
self.
|
62
|
+
self.reward_delay_val = reward_delay
|
63
63
|
|
64
64
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
65
65
|
"""Default reward function."""
|
@@ -70,9 +70,9 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
70
70
|
min_delay, max_delay = self.regen_delay_range
|
71
71
|
return jax.random.randint(rng, (), min_delay, max_delay)
|
72
72
|
|
73
|
-
def
|
74
|
-
"""Default
|
75
|
-
return self.
|
73
|
+
def reward_delay(self, clock: int, rng: jax.Array) -> int:
|
74
|
+
"""Default reward delay function."""
|
75
|
+
return self.reward_delay_val
|
76
76
|
|
77
77
|
|
78
78
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
@@ -87,8 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
87
87
|
std_regen_delay: int = 1,
|
88
88
|
color: Tuple[int, int, int] = (0, 0, 0),
|
89
89
|
random_respawn: bool = False,
|
90
|
-
|
91
|
-
|
90
|
+
reward_delay: int = 0,
|
91
|
+
max_reward_delay: Optional[int] = None,
|
92
92
|
):
|
93
93
|
super().__init__(
|
94
94
|
name=name,
|
@@ -97,8 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
97
97
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
98
98
|
color=color,
|
99
99
|
random_respawn=random_respawn,
|
100
|
-
|
101
|
-
|
100
|
+
reward_delay=reward_delay,
|
101
|
+
max_reward_delay=max_reward_delay,
|
102
102
|
)
|
103
103
|
self.mean_regen_delay = mean_regen_delay
|
104
104
|
self.std_regen_delay = std_regen_delay
|
@@ -122,8 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
122
122
|
std_regen_delay: int = 1,
|
123
123
|
color: Tuple[int, int, int] = (0, 0, 0),
|
124
124
|
random_respawn: bool = False,
|
125
|
-
|
126
|
-
|
125
|
+
reward_delay: int = 0,
|
126
|
+
max_reward_delay: Optional[int] = None,
|
127
127
|
):
|
128
128
|
super().__init__(
|
129
129
|
name=name,
|
@@ -132,16 +132,15 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
132
132
|
std_regen_delay=std_regen_delay,
|
133
133
|
color=color,
|
134
134
|
random_respawn=random_respawn,
|
135
|
-
|
136
|
-
|
135
|
+
reward_delay=reward_delay,
|
136
|
+
max_reward_delay=max_reward_delay,
|
137
137
|
)
|
138
|
-
self.rewards = rewards
|
138
|
+
self.rewards = rewards * multiplier
|
139
139
|
self.repeat = repeat
|
140
|
-
self.multiplier = multiplier
|
141
140
|
|
142
141
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
143
142
|
"""Reward is based on temperature."""
|
144
|
-
return get_temperature(self.rewards, clock, self.repeat)
|
143
|
+
return get_temperature(self.rewards, clock, self.repeat)
|
145
144
|
|
146
145
|
|
147
146
|
EMPTY = DefaultForagaxObject()
|
@@ -340,7 +339,7 @@ def create_weather_objects(
|
|
340
339
|
multiplier: float = 1.0,
|
341
340
|
same_color: bool = False,
|
342
341
|
random_respawn: bool = False,
|
343
|
-
|
342
|
+
reward_delay: int = 0,
|
344
343
|
):
|
345
344
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
346
345
|
|
@@ -370,7 +369,7 @@ def create_weather_objects(
|
|
370
369
|
multiplier=multiplier,
|
371
370
|
color=hot_color,
|
372
371
|
random_respawn=random_respawn,
|
373
|
-
|
372
|
+
reward_delay=reward_delay,
|
374
373
|
)
|
375
374
|
|
376
375
|
cold_color = hot_color if same_color else (0, 255, 255)
|
@@ -381,7 +380,7 @@ def create_weather_objects(
|
|
381
380
|
multiplier=-multiplier,
|
382
381
|
color=cold_color,
|
383
382
|
random_respawn=random_respawn,
|
384
|
-
|
383
|
+
reward_delay=reward_delay,
|
385
384
|
)
|
386
385
|
|
387
386
|
return hot, cold
|
foragax/registry.py
CHANGED
@@ -83,19 +83,6 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
83
83
|
"nowrap": False,
|
84
84
|
"deterministic_spawn": True,
|
85
85
|
},
|
86
|
-
"ForagaxWeather-v6": {
|
87
|
-
"size": (15, 15),
|
88
|
-
"aperture_size": None,
|
89
|
-
"objects": None,
|
90
|
-
"biomes": (
|
91
|
-
# Hot biome
|
92
|
-
Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
|
93
|
-
# Cold biome
|
94
|
-
Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
|
95
|
-
),
|
96
|
-
"nowrap": False,
|
97
|
-
"deterministic_spawn": True,
|
98
|
-
},
|
99
86
|
"ForagaxTwoBiome-v1": {
|
100
87
|
"size": (15, 15),
|
101
88
|
"aperture_size": None,
|
@@ -361,6 +348,8 @@ def make(
|
|
361
348
|
observation_type: str = "color",
|
362
349
|
aperture_size: Optional[Tuple[int, int]] = (5, 5),
|
363
350
|
file_index: int = 0,
|
351
|
+
repeat: int = 500,
|
352
|
+
reward_delay: int = 0,
|
364
353
|
**kwargs: Any,
|
365
354
|
) -> ForagaxEnv:
|
366
355
|
"""Create a Foragax environment.
|
@@ -371,6 +360,8 @@ def make(
|
|
371
360
|
aperture_size: The size of the agent's observation aperture. If -1, full world observation.
|
372
361
|
If None, the default for the environment is used.
|
373
362
|
file_index: File index for weather objects.
|
363
|
+
repeat: How many steps each temperature value repeats for (weather environments).
|
364
|
+
reward_delay: Number of steps required to digest food items (weather environments).
|
374
365
|
**kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
|
375
366
|
|
376
367
|
Returns:
|
@@ -467,19 +458,17 @@ def make(
|
|
467
458
|
"ForagaxWeather-v3",
|
468
459
|
"ForagaxWeather-v4",
|
469
460
|
"ForagaxWeather-v5",
|
470
|
-
"ForagaxWeather-v6",
|
471
461
|
)
|
472
462
|
random_respawn = env_id in (
|
473
463
|
"ForagaxWeather-v4",
|
474
464
|
"ForagaxWeather-v5",
|
475
|
-
"ForagaxWeather-v6",
|
476
465
|
)
|
477
|
-
digestion_steps = 10 if env_id in ("ForagaxWeather-v6") else 0
|
478
466
|
hot, cold = create_weather_objects(
|
479
467
|
file_index=file_index,
|
468
|
+
repeat=repeat,
|
480
469
|
same_color=same_color,
|
481
470
|
random_respawn=random_respawn,
|
482
|
-
|
471
|
+
reward_delay=reward_delay,
|
483
472
|
)
|
484
473
|
config["objects"] = (hot, cold)
|
485
474
|
|
File without changes
|
File without changes
|
File without changes
|