continual-foragax 0.8.2__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.9.0.dist-info}/METADATA +10 -1
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.9.0.dist-info}/RECORD +6 -6
- foragax/env.py +52 -5
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.9.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.9.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: continual-foragax
|
3
|
-
Version: 0.8.2
|
3
|
+
Version: 0.9.0
|
4
4
|
Summary: A continual reinforcement learning benchmark
|
5
5
|
Author-email: Steven Tang <stang5@ualberta.ca>
|
6
6
|
Requires-Python: >=3.8
|
@@ -119,3 +119,12 @@ class into registry configs or construct environments programmatically.
|
|
119
119
|
## Development
|
120
120
|
|
121
121
|
Run unit tests via pytest.
|
122
|
+
|
123
|
+
## Acknowledgments
|
124
|
+
|
125
|
+
We acknowledge the data providers in the ECA&D project. Klein Tank, A.M.G. and
|
126
|
+
Coauthors, 2002. Daily dataset of 20th-century surface air temperature and
|
127
|
+
precipitation series for the European Climate Assessment. Int. J. of Climatol.,
|
128
|
+
22, 1441-1453.
|
129
|
+
|
130
|
+
Data and metadata available at https://www.ecad.eu
|
@@ -1,5 +1,5 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
foragax/env.py,sha256=
|
2
|
+
foragax/env.py,sha256=5IJtONEGbW96bRCKxMy9efim-e8VGI1Ab4juZX3e0PY,17613
|
3
3
|
foragax/objects.py,sha256=lR6QnWX8xBK-EA91BhJGjfA8MNmUg_RXBZO2KAoAWzE,6004
|
4
4
|
foragax/registry.py,sha256=7_RDXvm_3RNO7culBLGkE0jH8Wk_q6jbMv72dZx4JO8,2722
|
5
5
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
@@ -126,8 +126,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
126
126
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
127
127
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
128
128
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
129
|
-
continual_foragax-0.
|
130
|
-
continual_foragax-0.
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.8.2.dist-info/RECORD,,
|
129
|
+
continual_foragax-0.9.0.dist-info/METADATA,sha256=2-IHrNN5vXoTGo557tCFRfD-49j9xf445rHn3a4My_c,4896
|
130
|
+
continual_foragax-0.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
131
|
+
continual_foragax-0.9.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
132
|
+
continual_foragax-0.9.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
133
|
+
continual_foragax-0.9.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
@@ -380,22 +380,69 @@ class ForagaxEnv(environment.Environment):
|
|
380
380
|
class ForagaxObjectEnv(ForagaxEnv):
|
381
381
|
"""Foragax environment with object-based aperture observation."""
|
382
382
|
|
383
|
+
def __init__(
|
384
|
+
self,
|
385
|
+
size: Union[Tuple[int, int], int] = (10, 10),
|
386
|
+
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
387
|
+
objects: Tuple[BaseForagaxObject, ...] = (),
|
388
|
+
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
389
|
+
):
|
390
|
+
super().__init__(size, aperture_size, objects, biomes)
|
391
|
+
|
392
|
+
# Compute unique colors and mapping for partial observability
|
393
|
+
# Exclude EMPTY (index 0) from color channels
|
394
|
+
object_colors_no_empty = self.object_colors[1:]
|
395
|
+
|
396
|
+
# Find unique colors in order of first appearance
|
397
|
+
unique_colors = []
|
398
|
+
color_indices = jnp.zeros(len(object_colors_no_empty), dtype=int)
|
399
|
+
color_map = {}
|
400
|
+
next_channel = 0
|
401
|
+
|
402
|
+
for i, color in enumerate(object_colors_no_empty):
|
403
|
+
color_tuple = tuple(color.tolist())
|
404
|
+
if color_tuple not in color_map:
|
405
|
+
color_map[color_tuple] = next_channel
|
406
|
+
unique_colors.append(color)
|
407
|
+
next_channel += 1
|
408
|
+
color_indices = color_indices.at[i].set(color_map[color_tuple])
|
409
|
+
|
410
|
+
self.unique_colors = jnp.array(unique_colors)
|
411
|
+
self.num_color_channels = len(unique_colors)
|
412
|
+
# color_indices maps from object_id-1 to color_channel_index
|
413
|
+
self.object_to_color_map = color_indices
|
414
|
+
|
383
415
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
384
|
-
num_obj_types = len(self.object_ids)
|
385
416
|
# Decode grid for observation
|
386
417
|
obs_grid = jnp.maximum(0, state.object_grid)
|
387
418
|
aperture = self._get_aperture(obs_grid, state.pos)
|
388
419
|
aperture = jnp.flip(aperture, axis=0)
|
389
|
-
|
390
|
-
|
420
|
+
|
421
|
+
# Handle case with no objects (only EMPTY)
|
422
|
+
if self.num_color_channels == 0:
|
423
|
+
return jnp.zeros(aperture.shape + (0,), dtype=jnp.float32)
|
424
|
+
|
425
|
+
# Map object IDs to color channel indices
|
426
|
+
# aperture contains object IDs (0 = EMPTY, 1+ = objects)
|
427
|
+
# For EMPTY (0), we want no color channel activated
|
428
|
+
# For objects (1+), map to color channel using object_to_color_map
|
429
|
+
color_channels = jnp.where(
|
430
|
+
aperture == 0,
|
431
|
+
-1, # Special value for EMPTY
|
432
|
+
jnp.take(self.object_to_color_map, aperture - 1, axis=0),
|
433
|
+
)
|
434
|
+
|
435
|
+
# Create one-hot encoding for color channels
|
436
|
+
# jax.nn.one_hot produces all zeros for -1 (EMPTY positions)
|
437
|
+
obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
|
438
|
+
|
391
439
|
return obs
|
392
440
|
|
393
441
|
def observation_space(self, params: EnvParams) -> spaces.Box:
|
394
|
-
num_obj_types = len(self.object_ids)
|
395
442
|
obs_shape = (
|
396
443
|
self.aperture_size[0],
|
397
444
|
self.aperture_size[1],
|
398
|
-
|
445
|
+
self.num_color_channels,
|
399
446
|
)
|
400
447
|
return spaces.Box(0, 1, obs_shape, float)
|
401
448
|
|
File without changes
|
File without changes
|
File without changes
|