continual-foragax 0.36.0__py3-none-any.whl → 0.37.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.36.0
3
+ Version: 0.37.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=K3noPwdYmQlnXVjslqVzX_FIB-CnOh37mWFArQXnf_Y,66324
4
- foragax/objects.py,sha256=PPuLYjD7em7GL404eSpP6q8TxF8p7JtQ1kIwh7uD_tU,26860
5
- foragax/registry.py,sha256=y_MEvM0K_E4fzdp4XokzLh1D94UAuCu16wGwJV4wa_A,19289
3
+ foragax/env.py,sha256=7IhLCEosM_IH19sMzRnGMvTNq9tlhVg4PU52hGz9XLE,66531
4
+ foragax/objects.py,sha256=aVc7lD3CTyRP9wm_Vs93qo4l_B1kbiYGKPtkd_SVXjs,27061
5
+ foragax/registry.py,sha256=dTdRwNNW8jfeDxioYCx4-GtL-J8t1nGwt7gDwDvN5TY,20016
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.36.0.dist-info/METADATA,sha256=FZVOu8G1yekPfxV7NxyWBXWRdTECWYB3gh08v2o4NIQ,4713
132
- continual_foragax-0.36.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.36.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.36.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.36.0.dist-info/RECORD,,
131
+ continual_foragax-0.37.0.dist-info/METADATA,sha256=fWAIh_Yq86ibXEtNLXKrqjcgxtoMPS0Em_q_UU8XZ8Q,4713
132
+ continual_foragax-0.37.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.37.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.37.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.37.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -132,6 +132,7 @@ class ForagaxEnv(environment.Environment):
132
132
  observation_type: str = "object",
133
133
  dynamic_biomes: bool = False,
134
134
  biome_consumption_threshold: float = 0.9,
135
+ max_expiries_per_step: int = 1,
135
136
  ):
136
137
  super().__init__()
137
138
  self._name = name
@@ -155,6 +156,9 @@ class ForagaxEnv(environment.Environment):
155
156
  self.teleport_interval = teleport_interval
156
157
  self.dynamic_biomes = dynamic_biomes
157
158
  self.biome_consumption_threshold = biome_consumption_threshold
159
+ if max_expiries_per_step < 1:
160
+ raise ValueError("max_expiries_per_step must be at least 1")
161
+ self.max_expiries_per_step = max_expiries_per_step
158
162
 
159
163
  objects = (EMPTY,) + objects
160
164
  if self.nowrap and not self.full_world:
@@ -608,63 +612,63 @@ class ForagaxEnv(environment.Environment):
608
612
  & (current_objects_for_expiry > 0)
609
613
  )
610
614
 
611
- # Count how many objects actually need to expire
612
- num_expiring = jnp.sum(should_expire)
613
-
614
615
  # Only process expiry if there are actually objects to expire
615
- def process_expiries():
616
- # Get positions of objects that should expire
617
- # Use nonzero with fixed size to maintain JIT compatibility
618
- max_objects = self.size[0] * self.size[1]
619
- y_indices, x_indices = jnp.nonzero(
620
- should_expire, size=max_objects, fill_value=-1
621
- )
616
+ has_expiring = jnp.any(should_expire)
617
+
618
+ # Precompute the first expiring index in flat space so the work inside cond is minimal.
619
+ overage = jnp.where(
620
+ should_expire,
621
+ object_ages - expiry_times,
622
+ -jnp.inf,
623
+ ).reshape(-1)
624
+ sorted_flat_indices = jnp.argsort(overage)[::-1]
625
+ selected_flat_indices = jnp.where(
626
+ overage[sorted_flat_indices] > -jnp.inf,
627
+ sorted_flat_indices,
628
+ -jnp.ones_like(sorted_flat_indices),
629
+ )[: self.max_expiries_per_step]
622
630
 
631
+ def process_expiries():
623
632
  key_local, expiry_key = jax.random.split(key)
624
633
 
625
- def process_one_expiry(carry, i):
626
- obj_state = carry
627
- y = y_indices[i]
628
- x = x_indices[i]
634
+ def body_fn(i, obj_state):
635
+ flat_idx = selected_flat_indices[i]
629
636
 
630
- # Skip if this is a padding index (from fill_value)
631
- is_valid = (y >= 0) & (x >= 0)
632
-
633
- def expire_one():
637
+ def expire_at(obj_state):
638
+ y = flat_idx // self.size[0]
639
+ x = flat_idx % self.size[0]
634
640
  obj_id = current_objects_for_expiry[y, x]
635
- exp_key = jax.random.fold_in(expiry_key, y * self.size[0] + x)
641
+ exp_key = jax.random.fold_in(expiry_key, flat_idx)
636
642
  exp_delay = jax.lax.switch(
637
643
  obj_id, self.expiry_regen_delay_fns, state.time, exp_key
638
644
  )
639
645
  timer_countdown = jax.lax.cond(
640
646
  exp_delay == jnp.iinfo(jnp.int32).max,
641
- lambda: 0, # No timer (permanent removal)
642
- lambda: exp_delay + 1, # Timer countdown
647
+ lambda: 0,
648
+ lambda: exp_delay + 1,
643
649
  )
644
650
 
645
- # Use unified timer placement method
646
- rand_key = jax.random.split(exp_key)[1]
647
- new_obj_state = self._place_timer(
651
+ respawn_random = self.object_random_respawn[obj_id]
652
+ rand_key = jax.random.fold_in(exp_key, 1)
653
+ return self._place_timer(
648
654
  obj_state,
649
655
  y,
650
656
  x,
651
657
  obj_id,
652
658
  timer_countdown,
653
- self.object_random_respawn[obj_id],
659
+ respawn_random,
654
660
  rand_key,
655
661
  )
656
662
 
657
- return new_obj_state
658
-
659
- def no_op():
660
- return obj_state
661
-
662
- return jax.lax.cond(is_valid, expire_one, no_op), None
663
+ return jax.lax.cond(
664
+ flat_idx >= 0,
665
+ expire_at,
666
+ lambda obj_state: obj_state,
667
+ obj_state,
668
+ )
663
669
 
664
- new_object_state, _ = jax.lax.scan(
665
- process_one_expiry,
666
- object_state,
667
- jnp.arange(max_objects),
670
+ new_object_state = jax.lax.fori_loop(
671
+ 0, self.max_expiries_per_step, body_fn, object_state
668
672
  )
669
673
  return key_local, new_object_state
670
674
 
@@ -672,7 +676,7 @@ class ForagaxEnv(environment.Environment):
672
676
  return key, object_state
673
677
 
674
678
  key, object_state = jax.lax.cond(
675
- num_expiring > 0,
679
+ has_expiring,
676
680
  process_expiries,
677
681
  no_expiries,
678
682
  )
foragax/objects.py CHANGED
@@ -241,6 +241,7 @@ class FourierObject(BaseForagaxObject):
241
241
  reward_delay: int = 0,
242
242
  max_reward_delay: Optional[int] = None,
243
243
  regen_delay: Optional[Tuple[int, int]] = None,
244
+ reward_repeat: int = 1,
244
245
  ):
245
246
  if max_reward_delay is None:
246
247
  max_reward_delay = reward_delay
@@ -257,6 +258,7 @@ class FourierObject(BaseForagaxObject):
257
258
  self.base_magnitude = base_magnitude
258
259
  self.reward_delay_val = reward_delay
259
260
  self.regen_delay_range = regen_delay
261
+ self.reward_repeat = reward_repeat
260
262
 
261
263
  def get_state(self, key: jax.Array) -> jax.Array:
262
264
  """Generate random Fourier series parameters.
@@ -321,7 +323,7 @@ class FourierObject(BaseForagaxObject):
321
323
  y_max = params[2]
322
324
 
323
325
  # Normalize time to [0, 2π] using the object's period
324
- t = 2.0 * jnp.pi * (clock % period) / period
326
+ t = 2.0 * jnp.pi * ((clock // self.reward_repeat) % period) / period
325
327
 
326
328
  # Extract interleaved coefficients: [a1, b1, a2, b2, ...]
327
329
  ab_coeffs = params[3:]
@@ -715,6 +717,7 @@ def create_fourier_objects(
715
717
  base_magnitude: float = 1.0,
716
718
  reward_delay: int = 0,
717
719
  regen_delay: Optional[Tuple[int, int]] = None,
720
+ reward_repeat: int = 1,
718
721
  ):
719
722
  """Create HOT and COLD FourierObject instances.
720
723
 
@@ -733,6 +736,7 @@ def create_fourier_objects(
733
736
  color=(0, 0, 0),
734
737
  reward_delay=reward_delay,
735
738
  regen_delay=regen_delay,
739
+ reward_repeat=reward_repeat,
736
740
  )
737
741
 
738
742
  cold = FourierObject(
@@ -742,6 +746,7 @@ def create_fourier_objects(
742
746
  color=(0, 0, 0),
743
747
  reward_delay=reward_delay,
744
748
  regen_delay=regen_delay,
749
+ reward_repeat=reward_repeat,
745
750
  )
746
751
 
747
752
  return hot, cold
foragax/registry.py CHANGED
@@ -128,6 +128,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
128
128
  "deterministic_spawn": True,
129
129
  "dynamic_biomes": True,
130
130
  },
131
+ "ForagaxDiwali-v4": {
132
+ "size": (15, 15),
133
+ "aperture_size": None,
134
+ "objects": None,
135
+ "biomes": (
136
+ # Hot biome
137
+ Biome(start=(0, 2), stop=(15, 6), object_frequencies=(0.5, 0.0)),
138
+ # Cold biome
139
+ Biome(start=(0, 9), stop=(15, 13), object_frequencies=(0.0, 0.5)),
140
+ ),
141
+ "nowrap": False,
142
+ "deterministic_spawn": True,
143
+ "dynamic_biomes": True,
144
+ "biome_consumption_threshold": 1000,
145
+ },
131
146
  "ForagaxTwoBiome-v1": {
132
147
  "size": (15, 15),
133
148
  "aperture_size": None,
@@ -575,6 +590,13 @@ def make(
575
590
  reward_delay=reward_delay,
576
591
  regen_delay=(9, 11),
577
592
  )[:1]
593
+ if env_id == "ForagaxDiwali-v4":
594
+ config["objects"] = create_fourier_objects(
595
+ num_fourier_terms=10,
596
+ reward_delay=reward_delay,
597
+ regen_delay=(9, 11),
598
+ reward_repeat=100,
599
+ )
578
600
 
579
601
  if env_id == "ForagaxSineTwoBiome-v1":
580
602
  biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (