continual-foragax 0.35.0__py3-none-any.whl → 0.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.35.0
3
+ Version: 0.36.1
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=K3noPwdYmQlnXVjslqVzX_FIB-CnOh37mWFArQXnf_Y,66324
3
+ foragax/env.py,sha256=7IhLCEosM_IH19sMzRnGMvTNq9tlhVg4PU52hGz9XLE,66531
4
4
  foragax/objects.py,sha256=PPuLYjD7em7GL404eSpP6q8TxF8p7JtQ1kIwh7uD_tU,26860
5
- foragax/registry.py,sha256=Ph_Z3O5GpIjrgvbKL-8Iq-Kc6MqfZIsF9KDDDzm7N3o,18787
5
+ foragax/registry.py,sha256=y_MEvM0K_E4fzdp4XokzLh1D94UAuCu16wGwJV4wa_A,19289
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.35.0.dist-info/METADATA,sha256=pZ1uSNXsaYkaFaz7xvO9X7bee8Jr9RjT7nbwW3tx5ps,4713
132
- continual_foragax-0.35.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.35.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.35.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.35.0.dist-info/RECORD,,
131
+ continual_foragax-0.36.1.dist-info/METADATA,sha256=SgkncBLbu9_MerET0r6Zto2YSLqfvmF4VE9vYXXh87g,4713
132
+ continual_foragax-0.36.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.36.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.36.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.36.1.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -132,6 +132,7 @@ class ForagaxEnv(environment.Environment):
132
132
  observation_type: str = "object",
133
133
  dynamic_biomes: bool = False,
134
134
  biome_consumption_threshold: float = 0.9,
135
+ max_expiries_per_step: int = 1,
135
136
  ):
136
137
  super().__init__()
137
138
  self._name = name
@@ -155,6 +156,9 @@ class ForagaxEnv(environment.Environment):
155
156
  self.teleport_interval = teleport_interval
156
157
  self.dynamic_biomes = dynamic_biomes
157
158
  self.biome_consumption_threshold = biome_consumption_threshold
159
+ if max_expiries_per_step < 1:
160
+ raise ValueError("max_expiries_per_step must be at least 1")
161
+ self.max_expiries_per_step = max_expiries_per_step
158
162
 
159
163
  objects = (EMPTY,) + objects
160
164
  if self.nowrap and not self.full_world:
@@ -608,63 +612,63 @@ class ForagaxEnv(environment.Environment):
608
612
  & (current_objects_for_expiry > 0)
609
613
  )
610
614
 
611
- # Count how many objects actually need to expire
612
- num_expiring = jnp.sum(should_expire)
613
-
614
615
  # Only process expiry if there are actually objects to expire
615
- def process_expiries():
616
- # Get positions of objects that should expire
617
- # Use nonzero with fixed size to maintain JIT compatibility
618
- max_objects = self.size[0] * self.size[1]
619
- y_indices, x_indices = jnp.nonzero(
620
- should_expire, size=max_objects, fill_value=-1
621
- )
616
+ has_expiring = jnp.any(should_expire)
617
+
618
+ # Precompute the first expiring index in flat space so the work inside cond is minimal.
619
+ overage = jnp.where(
620
+ should_expire,
621
+ object_ages - expiry_times,
622
+ -jnp.inf,
623
+ ).reshape(-1)
624
+ sorted_flat_indices = jnp.argsort(overage)[::-1]
625
+ selected_flat_indices = jnp.where(
626
+ overage[sorted_flat_indices] > -jnp.inf,
627
+ sorted_flat_indices,
628
+ -jnp.ones_like(sorted_flat_indices),
629
+ )[: self.max_expiries_per_step]
622
630
 
631
+ def process_expiries():
623
632
  key_local, expiry_key = jax.random.split(key)
624
633
 
625
- def process_one_expiry(carry, i):
626
- obj_state = carry
627
- y = y_indices[i]
628
- x = x_indices[i]
634
+ def body_fn(i, obj_state):
635
+ flat_idx = selected_flat_indices[i]
629
636
 
630
- # Skip if this is a padding index (from fill_value)
631
- is_valid = (y >= 0) & (x >= 0)
632
-
633
- def expire_one():
637
+ def expire_at(obj_state):
638
+ y = flat_idx // self.size[0]
639
+ x = flat_idx % self.size[0]
634
640
  obj_id = current_objects_for_expiry[y, x]
635
- exp_key = jax.random.fold_in(expiry_key, y * self.size[0] + x)
641
+ exp_key = jax.random.fold_in(expiry_key, flat_idx)
636
642
  exp_delay = jax.lax.switch(
637
643
  obj_id, self.expiry_regen_delay_fns, state.time, exp_key
638
644
  )
639
645
  timer_countdown = jax.lax.cond(
640
646
  exp_delay == jnp.iinfo(jnp.int32).max,
641
- lambda: 0, # No timer (permanent removal)
642
- lambda: exp_delay + 1, # Timer countdown
647
+ lambda: 0,
648
+ lambda: exp_delay + 1,
643
649
  )
644
650
 
645
- # Use unified timer placement method
646
- rand_key = jax.random.split(exp_key)[1]
647
- new_obj_state = self._place_timer(
651
+ respawn_random = self.object_random_respawn[obj_id]
652
+ rand_key = jax.random.fold_in(exp_key, 1)
653
+ return self._place_timer(
648
654
  obj_state,
649
655
  y,
650
656
  x,
651
657
  obj_id,
652
658
  timer_countdown,
653
- self.object_random_respawn[obj_id],
659
+ respawn_random,
654
660
  rand_key,
655
661
  )
656
662
 
657
- return new_obj_state
658
-
659
- def no_op():
660
- return obj_state
661
-
662
- return jax.lax.cond(is_valid, expire_one, no_op), None
663
+ return jax.lax.cond(
664
+ flat_idx >= 0,
665
+ expire_at,
666
+ lambda obj_state: obj_state,
667
+ obj_state,
668
+ )
663
669
 
664
- new_object_state, _ = jax.lax.scan(
665
- process_one_expiry,
666
- object_state,
667
- jnp.arange(max_objects),
670
+ new_object_state = jax.lax.fori_loop(
671
+ 0, self.max_expiries_per_step, body_fn, object_state
668
672
  )
669
673
  return key_local, new_object_state
670
674
 
@@ -672,7 +676,7 @@ class ForagaxEnv(environment.Environment):
672
676
  return key, object_state
673
677
 
674
678
  key, object_state = jax.lax.cond(
675
- num_expiring > 0,
679
+ has_expiring,
676
680
  process_expiries,
677
681
  no_expiries,
678
682
  )
foragax/registry.py CHANGED
@@ -119,6 +119,15 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
119
119
  "dynamic_biomes": True,
120
120
  "biome_consumption_threshold": 200,
121
121
  },
122
+ "ForagaxDiwali-v3": {
123
+ "size": (15, 15),
124
+ "aperture_size": None,
125
+ "objects": None,
126
+ "biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.5,)),),
127
+ "nowrap": False,
128
+ "deterministic_spawn": True,
129
+ "dynamic_biomes": True,
130
+ },
122
131
  "ForagaxTwoBiome-v1": {
123
132
  "size": (15, 15),
124
133
  "aperture_size": None,
@@ -560,6 +569,12 @@ def make(
560
569
  reward_delay=reward_delay,
561
570
  regen_delay=(9, 11),
562
571
  )
572
+ if env_id == "ForagaxDiwali-v3":
573
+ config["objects"] = create_fourier_objects(
574
+ num_fourier_terms=10,
575
+ reward_delay=reward_delay,
576
+ regen_delay=(9, 11),
577
+ )[:1]
563
578
 
564
579
  if env_id == "ForagaxSineTwoBiome-v1":
565
580
  biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (