continual-foragax 0.8.2__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.8.2
3
+ Version: 0.9.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -119,3 +119,12 @@ class into registry configs or construct environments programmatically.
119
119
  ## Development
120
120
 
121
121
  Run unit tests via pytest.
122
+
123
+ ## Acknowledgments
124
+
125
+ We acknowledge the data providers in the ECA&D project. Klein Tank, A.M.G. and
126
+ Coauthors, 2002. Daily dataset of 20th-century surface air temperature and
127
+ precipitation series for the European Climate Assessment. Int. J. of Climatol.,
128
+ 22, 1441-1453.
129
+
130
+ Data and metadata available at https://www.ecad.eu
@@ -1,5 +1,5 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- foragax/env.py,sha256=mT9qfRfmzhFntJ1KKsaP4gKRNqijsJ4hh2sgbiyRBgM,15659
2
+ foragax/env.py,sha256=5IJtONEGbW96bRCKxMy9efim-e8VGI1Ab4juZX3e0PY,17613
3
3
  foragax/objects.py,sha256=lR6QnWX8xBK-EA91BhJGjfA8MNmUg_RXBZO2KAoAWzE,6004
4
4
  foragax/registry.py,sha256=7_RDXvm_3RNO7culBLGkE0jH8Wk_q6jbMv72dZx4JO8,2722
5
5
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
@@ -126,8 +126,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
126
126
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
127
127
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
128
128
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
129
- continual_foragax-0.8.2.dist-info/METADATA,sha256=RlHNKYqScJI1DV7QEUrkZHKtsCr4vgea5yDzo9PZ1D8,4574
130
- continual_foragax-0.8.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
131
- continual_foragax-0.8.2.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
132
- continual_foragax-0.8.2.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
133
- continual_foragax-0.8.2.dist-info/RECORD,,
129
+ continual_foragax-0.9.0.dist-info/METADATA,sha256=2-IHrNN5vXoTGo557tCFRfD-49j9xf445rHn3a4My_c,4896
130
+ continual_foragax-0.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
131
+ continual_foragax-0.9.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
132
+ continual_foragax-0.9.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
133
+ continual_foragax-0.9.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -380,22 +380,69 @@ class ForagaxEnv(environment.Environment):
380
380
  class ForagaxObjectEnv(ForagaxEnv):
381
381
  """Foragax environment with object-based aperture observation."""
382
382
 
383
+ def __init__(
384
+ self,
385
+ size: Union[Tuple[int, int], int] = (10, 10),
386
+ aperture_size: Union[Tuple[int, int], int] = (5, 5),
387
+ objects: Tuple[BaseForagaxObject, ...] = (),
388
+ biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
389
+ ):
390
+ super().__init__(size, aperture_size, objects, biomes)
391
+
392
+ # Compute unique colors and mapping for partial observability
393
+ # Exclude EMPTY (index 0) from color channels
394
+ object_colors_no_empty = self.object_colors[1:]
395
+
396
+ # Find unique colors in order of first appearance
397
+ unique_colors = []
398
+ color_indices = jnp.zeros(len(object_colors_no_empty), dtype=int)
399
+ color_map = {}
400
+ next_channel = 0
401
+
402
+ for i, color in enumerate(object_colors_no_empty):
403
+ color_tuple = tuple(color.tolist())
404
+ if color_tuple not in color_map:
405
+ color_map[color_tuple] = next_channel
406
+ unique_colors.append(color)
407
+ next_channel += 1
408
+ color_indices = color_indices.at[i].set(color_map[color_tuple])
409
+
410
+ self.unique_colors = jnp.array(unique_colors)
411
+ self.num_color_channels = len(unique_colors)
412
+ # color_indices maps from object_id-1 to color_channel_index
413
+ self.object_to_color_map = color_indices
414
+
383
415
  def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
384
- num_obj_types = len(self.object_ids)
385
416
  # Decode grid for observation
386
417
  obs_grid = jnp.maximum(0, state.object_grid)
387
418
  aperture = self._get_aperture(obs_grid, state.pos)
388
419
  aperture = jnp.flip(aperture, axis=0)
389
- obs = jax.nn.one_hot(aperture, num_obj_types)
390
- obs = obs[:, :, 1:]
420
+
421
+ # Handle case with no objects (only EMPTY)
422
+ if self.num_color_channels == 0:
423
+ return jnp.zeros(aperture.shape + (0,), dtype=jnp.float32)
424
+
425
+ # Map object IDs to color channel indices
426
+ # aperture contains object IDs (0 = EMPTY, 1+ = objects)
427
+ # For EMPTY (0), we want no color channel activated
428
+ # For objects (1+), map to color channel using object_to_color_map
429
+ color_channels = jnp.where(
430
+ aperture == 0,
431
+ -1, # Special value for EMPTY
432
+ jnp.take(self.object_to_color_map, aperture - 1, axis=0),
433
+ )
434
+
435
+ # Create one-hot encoding for color channels
436
+ # jax.nn.one_hot produces all zeros for -1 (EMPTY positions)
437
+ obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
438
+
391
439
  return obs
392
440
 
393
441
  def observation_space(self, params: EnvParams) -> spaces.Box:
394
- num_obj_types = len(self.object_ids)
395
442
  obs_shape = (
396
443
  self.aperture_size[0],
397
444
  self.aperture_size[1],
398
- num_obj_types - 1,
445
+ self.num_color_channels,
399
446
  )
400
447
  return spaces.Box(0, 1, obs_shape, float)
401
448