continual-foragax 0.38.0__py3-none-any.whl → 0.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.38.0
3
+ Version: 0.40.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=9EB02GgSc80KRjj35z_UyM4VXTGuyHBAHX3d9M5AvY0,66966
3
+ foragax/env.py,sha256=-lBBh70JRInOFR6mmU8uyA9nzgIIPapIdY20y4-vwQU,67848
4
4
  foragax/objects.py,sha256=aVc7lD3CTyRP9wm_Vs93qo4l_B1kbiYGKPtkd_SVXjs,27061
5
- foragax/registry.py,sha256=u8mv5mPTHKRsIE7FTi36J5tXp3FuVJ-AmJRb7HWW9hI,20636
5
+ foragax/registry.py,sha256=2qi7Dq96RZVyBQwAnejdBztQWJOQHNAw26UUPBTzIEY,20551
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.38.0.dist-info/METADATA,sha256=ODZtbDZHv2GRj0yRC4rR34Hu0PaN_8e4Oo8XPmhWGxU,4713
132
- continual_foragax-0.38.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.38.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.38.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.38.0.dist-info/RECORD,,
131
+ continual_foragax-0.40.0.dist-info/METADATA,sha256=DN8INomefM2v426Q75KtXL-VqD_fsFKAjQMaphc9Rl4,4713
132
+ continual_foragax-0.40.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.40.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.40.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.40.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -6,7 +6,7 @@ Source: https://github.com/andnp/Forager
6
6
  from dataclasses import dataclass
7
7
  from enum import IntEnum
8
8
  from functools import partial
9
- from typing import Any, Dict, Optional, Tuple, Union
9
+ from typing import Any, Dict, Tuple, Union
10
10
 
11
11
  import jax
12
12
  import jax.numpy as jnp
@@ -128,7 +128,6 @@ class ForagaxEnv(environment.Environment):
128
128
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
129
129
  nowrap: bool = False,
130
130
  deterministic_spawn: bool = False,
131
- teleport_interval: Optional[int] = None,
132
131
  observation_type: str = "object",
133
132
  dynamic_biomes: bool = False,
134
133
  biome_consumption_threshold: float = 0.9,
@@ -154,7 +153,6 @@ class ForagaxEnv(environment.Environment):
154
153
  self.observation_type = observation_type
155
154
  self.nowrap = nowrap
156
155
  self.deterministic_spawn = deterministic_spawn
157
- self.teleport_interval = teleport_interval
158
156
  self.dynamic_biomes = dynamic_biomes
159
157
  self.biome_consumption_threshold = biome_consumption_threshold
160
158
  self.dynamic_biome_spawn_empty = dynamic_biome_spawn_empty
@@ -391,23 +389,6 @@ class ForagaxEnv(environment.Environment):
391
389
  is_blocking = self.object_blocking[obj_at_new_pos]
392
390
  pos = jax.lax.select(is_blocking, state.pos, new_pos)
393
391
 
394
- # Check for automatic teleport
395
- if self.teleport_interval is not None:
396
- should_teleport = jnp.mod(state.time + 1, self.teleport_interval) == 0
397
- else:
398
- should_teleport = False
399
-
400
- def teleport_fn():
401
- # Calculate squared distances from current position to each biome center
402
- diffs = self.biome_centers_jax - pos
403
- distances = jnp.sum(diffs**2, axis=1)
404
- # Find the index of the furthest biome center
405
- furthest_idx = jnp.argmax(distances)
406
- new_pos = self.biome_centers_jax[furthest_idx]
407
- return new_pos
408
-
409
- pos = jax.lax.cond(should_teleport, teleport_fn, lambda: pos)
410
-
411
392
  # 2. HANDLE COLLISIONS AND REWARDS
412
393
  obj_at_pos = current_objects[pos[1], pos[0]]
413
394
  is_collectable = self.object_collectable[obj_at_pos]
@@ -774,20 +755,53 @@ class ForagaxEnv(environment.Environment):
774
755
  biome_mask = self.biome_masks_array[i]
775
756
  new_gen_value = new_biome_state.generation[i]
776
757
 
777
- # Only update where new spawn has objects and biome should respawn
778
- is_new_object = (
779
- (all_new_objects[i] > 0) & biome_mask & should_respawn[i][..., None]
780
- )
758
+ # Update mask: biome area AND needs respawn
759
+ should_update = biome_mask & should_respawn[i][..., None]
760
+
761
+ # 1. Merge: Overwrite with new objects if present, otherwise keep existing
762
+ # If the new spawn has an object, we take it. If it's empty, we keep whatever was there.
763
+ new_spawn_valid = all_new_objects[i] > 0
781
764
 
782
- new_obj_id = jnp.where(is_new_object, all_new_objects[i], new_obj_id)
783
- new_color = jnp.where(
784
- is_new_object[..., None], all_new_colors[i], new_color
765
+ merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
766
+ merged_colors = jnp.where(
767
+ new_spawn_valid[..., None], all_new_colors[i], new_color
785
768
  )
786
- new_params = jnp.where(
787
- is_new_object[..., None], all_new_params[i], new_params
769
+ merged_params = jnp.where(
770
+ new_spawn_valid[..., None], all_new_params[i], new_params
788
771
  )
789
- new_gen = jnp.where(is_new_object, new_gen_value, new_gen)
790
- new_spawn = jnp.where(is_new_object, current_time, new_spawn)
772
+
773
+ # For generation/spawn time, update only if we took the NEW object
774
+ merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
775
+ merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
776
+
777
+ # 2. Apply dropout to the MERGED result (only where we are allowed to update)
778
+ if self.dynamic_biome_spawn_empty > 0:
779
+ key, dropout_key = jax.random.split(key)
780
+ # Dropout applies only to cells that are both in the biome AND need updating
781
+ keep_mask = jax.random.bernoulli(
782
+ dropout_key, 1.0 - self.dynamic_biome_spawn_empty, merged_objs.shape
783
+ )
784
+ # Only apply dropout where should_update is true; elsewhere, keep merged_objs
785
+ dropout_mask = should_update & keep_mask
786
+ # Apply dropout only to the merged result and associated metadata
787
+ final_objs = jnp.where(dropout_mask, merged_objs, 0)
788
+ final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
789
+ final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
790
+ final_gen = jnp.where(dropout_mask, merged_gen, 0)
791
+ final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
792
+ else:
793
+ final_objs = merged_objs
794
+ final_colors = merged_colors
795
+ final_params = merged_params
796
+ final_gen = merged_gen
797
+ final_spawn = merged_spawn
798
+
799
+ # 3. Write back: Only update where should_update is true
800
+ new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
801
+ new_color = jnp.where(should_update[..., None], final_colors, new_color)
802
+ new_params = jnp.where(should_update[..., None], final_params, new_params)
803
+ new_gen = jnp.where(should_update, final_gen, new_gen)
804
+ new_spawn = jnp.where(should_update, final_spawn, new_spawn)
791
805
 
792
806
  # Clear timers in respawning biomes
793
807
  new_respawn_timer = object_state.respawn_timer
@@ -890,7 +904,7 @@ class ForagaxEnv(environment.Environment):
890
904
  biome_freqs = self.biome_object_frequencies[biome_idx]
891
905
  biome_mask = self.biome_masks_array[biome_idx]
892
906
 
893
- key, spawn_key, color_key, params_key, dropout_key = jax.random.split(key, 5)
907
+ key, spawn_key, color_key, params_key = jax.random.split(key, 4)
894
908
 
895
909
  # Generate object IDs using deterministic or random spawn
896
910
  if deterministic:
@@ -929,13 +943,6 @@ class ForagaxEnv(environment.Environment):
929
943
  jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
930
944
  )
931
945
 
932
- # Apply random dropout if configured
933
- if self.dynamic_biome_spawn_empty > 0:
934
- dropout_mask = jax.random.bernoulli(
935
- dropout_key, 1.0 - self.dynamic_biome_spawn_empty, object_grid.shape
936
- )
937
- object_grid = jnp.where(dropout_mask, object_grid, 0)
938
-
939
946
  # Initialize color grid
940
947
  color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
941
948
 
foragax/registry.py CHANGED
@@ -147,12 +147,12 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
147
147
  "size": (15, 15),
148
148
  "aperture_size": None,
149
149
  "objects": None,
150
- "biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.2,)),),
150
+ "biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.4,)),),
151
151
  "nowrap": False,
152
152
  "deterministic_spawn": True,
153
153
  "dynamic_biomes": True,
154
154
  "biome_consumption_threshold": 1000,
155
- "dynamic_biome_spawn_empty": 0.1,
155
+ "dynamic_biome_spawn_empty": 0.4,
156
156
  },
157
157
  "ForagaxTwoBiome-v1": {
158
158
  "size": (15, 15),
@@ -636,9 +636,6 @@ def make(
636
636
  biome2_deathcap,
637
637
  )
638
638
 
639
- if env_id == "ForagaxTwoBiome-v16":
640
- config["teleport_interval"] = 10000
641
-
642
639
  # Backward compatibility: map "world" to "object" with full world
643
640
  if observation_type == "world":
644
641
  # add deprecation warning