continual-foragax 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.28.0.dist-info → continual_foragax-0.29.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.28.0.dist-info → continual_foragax-0.29.0.dist-info}/RECORD +8 -8
- foragax/env.py +52 -14
- foragax/objects.py +28 -4
- foragax/registry.py +25 -8
- {continual_foragax-0.28.0.dist-info → continual_foragax-0.29.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.28.0.dist-info → continual_foragax-0.29.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.28.0.dist-info → continual_foragax-0.29.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
-
foragax/env.py,sha256=
|
4
|
-
foragax/objects.py,sha256=
|
5
|
-
foragax/registry.py,sha256=
|
3
|
+
foragax/env.py,sha256=OFtDT8c5nskflnXmMNbRUvu5pVs9vhIezDgZaICXSyE,27535
|
4
|
+
foragax/objects.py,sha256=j7FivgT4uz6N4FkOTmpM0t-YjTUkYUBLAznWpVqqjrU,10509
|
5
|
+
foragax/registry.py,sha256=bFSTDo7XU4G0njHjLTgPdzWyStkAGckimDYcnGYLIIg,15529
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
134
|
-
continual_foragax-0.
|
135
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.29.0.dist-info/METADATA,sha256=vaiiCWr06OczH8LgsRotJXRt7q4KZMtnvuFj6up5v3U,4897
|
132
|
+
continual_foragax-0.29.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.29.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.29.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.29.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -61,6 +61,7 @@ class EnvState(environment.EnvState):
|
|
61
61
|
object_grid: jax.Array
|
62
62
|
biome_grid: jax.Array
|
63
63
|
time: int
|
64
|
+
digestion_buffer: jax.Array
|
64
65
|
|
65
66
|
|
66
67
|
class ForagaxEnv(environment.Environment):
|
@@ -99,14 +100,9 @@ class ForagaxEnv(environment.Environment):
|
|
99
100
|
self.deterministic_spawn = deterministic_spawn
|
100
101
|
self.teleport_interval = teleport_interval
|
101
102
|
objects = (EMPTY,) + objects
|
102
|
-
if self.nowrap:
|
103
|
+
if self.nowrap and not self.full_world:
|
103
104
|
objects = objects + (PADDING,)
|
104
105
|
self.objects = objects
|
105
|
-
self.weather_object = None
|
106
|
-
for o in objects:
|
107
|
-
if isinstance(o, WeatherObject):
|
108
|
-
self.weather_object = o
|
109
|
-
break
|
110
106
|
|
111
107
|
# JIT-compatible versions of object and biome properties
|
112
108
|
self.object_ids = jnp.arange(len(objects))
|
@@ -117,6 +113,13 @@ class ForagaxEnv(environment.Environment):
|
|
117
113
|
|
118
114
|
self.reward_fns = [o.reward for o in objects]
|
119
115
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
116
|
+
self.digestion_steps_fns = [o.digestion_steps for o in objects]
|
117
|
+
|
118
|
+
# Compute digestion steps per object (using max_digestion_steps attribute)
|
119
|
+
object_max_digestion_steps = jnp.array([o.max_digestion_steps for o in objects])
|
120
|
+
self.max_digestion_steps = (
|
121
|
+
int(jnp.max(object_max_digestion_steps)) + 1 if len(objects) > 0 else 0
|
122
|
+
)
|
120
123
|
|
121
124
|
self.biome_object_frequencies = jnp.array(
|
122
125
|
[b.object_frequencies for b in biomes]
|
@@ -237,12 +240,36 @@ class ForagaxEnv(environment.Environment):
|
|
237
240
|
|
238
241
|
# 2. HANDLE COLLISIONS AND REWARDS
|
239
242
|
obj_at_pos = current_objects[pos[1], pos[0]]
|
240
|
-
key, subkey = jax.random.split(key)
|
241
|
-
reward = jax.lax.switch(obj_at_pos, self.reward_fns, state.time, subkey)
|
242
243
|
is_collectable = self.object_collectable[obj_at_pos]
|
244
|
+
should_collect = is_collectable & (obj_at_pos > 0)
|
245
|
+
|
246
|
+
# Handle digestion: add reward to buffer if collected
|
247
|
+
digestion_buffer = state.digestion_buffer
|
248
|
+
key, reward_subkey = jax.random.split(key)
|
249
|
+
object_reward = jax.lax.switch(
|
250
|
+
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
251
|
+
)
|
252
|
+
key, digestion_subkey = jax.random.split(key)
|
253
|
+
digestion_steps = jax.lax.switch(
|
254
|
+
obj_at_pos, self.digestion_steps_fns, state.time, digestion_subkey
|
255
|
+
)
|
256
|
+
reward = jnp.where(should_collect & (digestion_steps == 0), object_reward, 0.0)
|
257
|
+
if self.max_digestion_steps > 0:
|
258
|
+
# Add delayed rewards to buffer
|
259
|
+
digestion_buffer = jax.lax.cond(
|
260
|
+
should_collect & (digestion_steps > 0),
|
261
|
+
lambda: digestion_buffer.at[
|
262
|
+
(state.time + digestion_steps) % self.max_digestion_steps
|
263
|
+
].add(object_reward),
|
264
|
+
lambda: digestion_buffer,
|
265
|
+
)
|
266
|
+
# Deliver current rewards
|
267
|
+
current_index = state.time % self.max_digestion_steps
|
268
|
+
reward += digestion_buffer[current_index]
|
269
|
+
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
243
270
|
|
244
271
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
245
|
-
key,
|
272
|
+
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
246
273
|
|
247
274
|
# Decrement timers (stored as negative values)
|
248
275
|
is_timer = state.object_grid < 0
|
@@ -252,7 +279,7 @@ class ForagaxEnv(environment.Environment):
|
|
252
279
|
|
253
280
|
# Collect object: set a timer
|
254
281
|
regen_delay = jax.lax.switch(
|
255
|
-
obj_at_pos, self.regen_delay_fns, state.time,
|
282
|
+
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
256
283
|
)
|
257
284
|
encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
|
258
285
|
|
@@ -301,10 +328,13 @@ class ForagaxEnv(environment.Environment):
|
|
301
328
|
)
|
302
329
|
|
303
330
|
info = {"discount": self.discount(state, params)}
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
331
|
+
temperatures = jnp.zeros(len(self.objects))
|
332
|
+
for obj_index, obj in enumerate(self.objects):
|
333
|
+
if isinstance(obj, WeatherObject):
|
334
|
+
temperatures = temperatures.at[obj_index].set(
|
335
|
+
get_temperature(obj.rewards, state.time, obj.repeat)
|
336
|
+
)
|
337
|
+
info["temperatures"] = temperatures
|
308
338
|
info["biome_id"] = state.biome_grid[pos[1], pos[0]]
|
309
339
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
310
340
|
|
@@ -314,6 +344,7 @@ class ForagaxEnv(environment.Environment):
|
|
314
344
|
object_grid=object_grid,
|
315
345
|
biome_grid=state.biome_grid,
|
316
346
|
time=state.time + 1,
|
347
|
+
digestion_buffer=digestion_buffer,
|
317
348
|
)
|
318
349
|
|
319
350
|
done = self.is_terminal(state, params)
|
@@ -352,6 +383,7 @@ class ForagaxEnv(environment.Environment):
|
|
352
383
|
object_grid=object_grid,
|
353
384
|
biome_grid=biome_grid,
|
354
385
|
time=0,
|
386
|
+
digestion_buffer=jnp.zeros((self.max_digestion_steps,)),
|
355
387
|
)
|
356
388
|
|
357
389
|
return self.get_obs(state, params), state
|
@@ -412,6 +444,12 @@ class ForagaxEnv(environment.Environment):
|
|
412
444
|
int,
|
413
445
|
),
|
414
446
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
447
|
+
"digestion_buffer": spaces.Box(
|
448
|
+
-jnp.inf,
|
449
|
+
jnp.inf,
|
450
|
+
(self.max_digestion_steps,),
|
451
|
+
float,
|
452
|
+
),
|
415
453
|
}
|
416
454
|
)
|
417
455
|
|
foragax/objects.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import abc
|
2
|
-
from typing import Tuple
|
2
|
+
from typing import Optional, Tuple
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
@@ -17,12 +17,14 @@ class BaseForagaxObject:
|
|
17
17
|
collectable: bool = False,
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
19
19
|
random_respawn: bool = False,
|
20
|
+
max_digestion_steps: int = 0,
|
20
21
|
):
|
21
22
|
self.name = name
|
22
23
|
self.blocking = blocking
|
23
24
|
self.collectable = collectable
|
24
25
|
self.color = color
|
25
26
|
self.random_respawn = random_respawn
|
27
|
+
self.max_digestion_steps = max_digestion_steps
|
26
28
|
|
27
29
|
@abc.abstractmethod
|
28
30
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
@@ -30,8 +32,8 @@ class BaseForagaxObject:
|
|
30
32
|
raise NotImplementedError
|
31
33
|
|
32
34
|
@abc.abstractmethod
|
33
|
-
def
|
34
|
-
"""
|
35
|
+
def digestion_steps(self, clock: int, rng: jax.Array) -> int:
|
36
|
+
"""Digestion steps function."""
|
35
37
|
raise NotImplementedError
|
36
38
|
|
37
39
|
|
@@ -47,10 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
47
49
|
regen_delay: Tuple[int, int] = (10, 100),
|
48
50
|
color: Tuple[int, int, int] = (255, 255, 255),
|
49
51
|
random_respawn: bool = False,
|
52
|
+
digestion_steps: int = 0,
|
53
|
+
max_digestion_steps: Optional[int] = None,
|
50
54
|
):
|
51
|
-
|
55
|
+
if max_digestion_steps is None:
|
56
|
+
max_digestion_steps = digestion_steps
|
57
|
+
super().__init__(
|
58
|
+
name, blocking, collectable, color, random_respawn, max_digestion_steps
|
59
|
+
)
|
52
60
|
self.reward_val = reward
|
53
61
|
self.regen_delay_range = regen_delay
|
62
|
+
self.digestion_steps_val = digestion_steps
|
54
63
|
|
55
64
|
def reward(self, clock: int, rng: jax.Array) -> float:
|
56
65
|
"""Default reward function."""
|
@@ -61,6 +70,10 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
61
70
|
min_delay, max_delay = self.regen_delay_range
|
62
71
|
return jax.random.randint(rng, (), min_delay, max_delay)
|
63
72
|
|
73
|
+
def digestion_steps(self, clock: int, rng: jax.Array) -> int:
|
74
|
+
"""Default digestion steps function."""
|
75
|
+
return self.digestion_steps_val
|
76
|
+
|
64
77
|
|
65
78
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
66
79
|
"""Object with regeneration delay from a normal distribution."""
|
@@ -74,6 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
74
87
|
std_regen_delay: int = 1,
|
75
88
|
color: Tuple[int, int, int] = (0, 0, 0),
|
76
89
|
random_respawn: bool = False,
|
90
|
+
digestion_steps: int = 0,
|
91
|
+
max_digestion_steps: Optional[int] = None,
|
77
92
|
):
|
78
93
|
super().__init__(
|
79
94
|
name=name,
|
@@ -82,6 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
82
97
|
regen_delay=(mean_regen_delay, mean_regen_delay),
|
83
98
|
color=color,
|
84
99
|
random_respawn=random_respawn,
|
100
|
+
digestion_steps=digestion_steps,
|
101
|
+
max_digestion_steps=max_digestion_steps,
|
85
102
|
)
|
86
103
|
self.mean_regen_delay = mean_regen_delay
|
87
104
|
self.std_regen_delay = std_regen_delay
|
@@ -105,6 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
105
122
|
std_regen_delay: int = 1,
|
106
123
|
color: Tuple[int, int, int] = (0, 0, 0),
|
107
124
|
random_respawn: bool = False,
|
125
|
+
digestion_steps: int = 0,
|
126
|
+
max_digestion_steps: Optional[int] = None,
|
108
127
|
):
|
109
128
|
super().__init__(
|
110
129
|
name=name,
|
@@ -113,6 +132,8 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
113
132
|
std_regen_delay=std_regen_delay,
|
114
133
|
color=color,
|
115
134
|
random_respawn=random_respawn,
|
135
|
+
digestion_steps=digestion_steps,
|
136
|
+
max_digestion_steps=max_digestion_steps,
|
116
137
|
)
|
117
138
|
self.rewards = rewards
|
118
139
|
self.repeat = repeat
|
@@ -319,6 +340,7 @@ def create_weather_objects(
|
|
319
340
|
multiplier: float = 1.0,
|
320
341
|
same_color: bool = False,
|
321
342
|
random_respawn: bool = False,
|
343
|
+
digestion_steps: int = 0,
|
322
344
|
):
|
323
345
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
324
346
|
|
@@ -348,6 +370,7 @@ def create_weather_objects(
|
|
348
370
|
multiplier=multiplier,
|
349
371
|
color=hot_color,
|
350
372
|
random_respawn=random_respawn,
|
373
|
+
digestion_steps=digestion_steps,
|
351
374
|
)
|
352
375
|
|
353
376
|
cold_color = hot_color if same_color else (0, 255, 255)
|
@@ -358,6 +381,7 @@ def create_weather_objects(
|
|
358
381
|
multiplier=-multiplier,
|
359
382
|
color=cold_color,
|
360
383
|
random_respawn=random_respawn,
|
384
|
+
digestion_steps=digestion_steps,
|
361
385
|
)
|
362
386
|
|
363
387
|
return hot, cold
|
foragax/registry.py
CHANGED
@@ -83,6 +83,19 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
83
83
|
"nowrap": False,
|
84
84
|
"deterministic_spawn": True,
|
85
85
|
},
|
86
|
+
"ForagaxWeather-v6": {
|
87
|
+
"size": (15, 15),
|
88
|
+
"aperture_size": None,
|
89
|
+
"objects": None,
|
90
|
+
"biomes": (
|
91
|
+
# Hot biome
|
92
|
+
Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
|
93
|
+
# Cold biome
|
94
|
+
Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
|
95
|
+
),
|
96
|
+
"nowrap": False,
|
97
|
+
"deterministic_spawn": True,
|
98
|
+
},
|
86
99
|
"ForagaxTwoBiome-v1": {
|
87
100
|
"size": (15, 15),
|
88
101
|
"aperture_size": None,
|
@@ -348,7 +361,6 @@ def make(
|
|
348
361
|
observation_type: str = "color",
|
349
362
|
aperture_size: Optional[Tuple[int, int]] = (5, 5),
|
350
363
|
file_index: int = 0,
|
351
|
-
nowrap: Optional[bool] = None,
|
352
364
|
**kwargs: Any,
|
353
365
|
) -> ForagaxEnv:
|
354
366
|
"""Create a Foragax environment.
|
@@ -358,9 +370,7 @@ def make(
|
|
358
370
|
observation_type: The type of observation to use. One of "object", "rgb", or "color".
|
359
371
|
aperture_size: The size of the agent's observation aperture. If -1, full world observation.
|
360
372
|
If None, the default for the environment is used.
|
361
|
-
file_index: File index for weather objects.
|
362
|
-
wrapping around environment boundaries. If None, uses defaults per
|
363
|
-
environment.
|
373
|
+
file_index: File index for weather objects.
|
364
374
|
**kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
|
365
375
|
|
366
376
|
Returns:
|
@@ -376,8 +386,6 @@ def make(
|
|
376
386
|
else:
|
377
387
|
aperture_size = (aperture_size, aperture_size)
|
378
388
|
config["aperture_size"] = aperture_size
|
379
|
-
if nowrap is not None:
|
380
|
-
config["nowrap"] = nowrap
|
381
389
|
|
382
390
|
# Handle special size and biome configurations
|
383
391
|
if env_id in (
|
@@ -459,10 +467,19 @@ def make(
|
|
459
467
|
"ForagaxWeather-v3",
|
460
468
|
"ForagaxWeather-v4",
|
461
469
|
"ForagaxWeather-v5",
|
470
|
+
"ForagaxWeather-v6",
|
471
|
+
)
|
472
|
+
random_respawn = env_id in (
|
473
|
+
"ForagaxWeather-v4",
|
474
|
+
"ForagaxWeather-v5",
|
475
|
+
"ForagaxWeather-v6",
|
462
476
|
)
|
463
|
-
|
477
|
+
digestion_steps = 10 if env_id in ("ForagaxWeather-v6") else 0
|
464
478
|
hot, cold = create_weather_objects(
|
465
|
-
file_index=file_index,
|
479
|
+
file_index=file_index,
|
480
|
+
same_color=same_color,
|
481
|
+
random_respawn=random_respawn,
|
482
|
+
digestion_steps=digestion_steps,
|
466
483
|
)
|
467
484
|
config["objects"] = (hot, cold)
|
468
485
|
|
File without changes
|
File without changes
|
File without changes
|