continual-foragax 0.36.0__py3-none-any.whl → 0.37.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.36.0.dist-info → continual_foragax-0.37.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.36.0.dist-info → continual_foragax-0.37.0.dist-info}/RECORD +8 -8
- foragax/env.py +40 -36
- foragax/objects.py +6 -1
- foragax/registry.py +22 -0
- {continual_foragax-0.36.0.dist-info → continual_foragax-0.37.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.36.0.dist-info → continual_foragax-0.37.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.36.0.dist-info → continual_foragax-0.37.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
4
|
-
foragax/objects.py,sha256=
|
|
5
|
-
foragax/registry.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=7IhLCEosM_IH19sMzRnGMvTNq9tlhVg4PU52hGz9XLE,66531
|
|
4
|
+
foragax/objects.py,sha256=aVc7lD3CTyRP9wm_Vs93qo4l_B1kbiYGKPtkd_SVXjs,27061
|
|
5
|
+
foragax/registry.py,sha256=dTdRwNNW8jfeDxioYCx4-GtL-J8t1nGwt7gDwDvN5TY,20016
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.
|
|
133
|
-
continual_foragax-0.
|
|
134
|
-
continual_foragax-0.
|
|
135
|
-
continual_foragax-0.
|
|
131
|
+
continual_foragax-0.37.0.dist-info/METADATA,sha256=fWAIh_Yq86ibXEtNLXKrqjcgxtoMPS0Em_q_UU8XZ8Q,4713
|
|
132
|
+
continual_foragax-0.37.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.37.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.37.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.37.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -132,6 +132,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
132
132
|
observation_type: str = "object",
|
|
133
133
|
dynamic_biomes: bool = False,
|
|
134
134
|
biome_consumption_threshold: float = 0.9,
|
|
135
|
+
max_expiries_per_step: int = 1,
|
|
135
136
|
):
|
|
136
137
|
super().__init__()
|
|
137
138
|
self._name = name
|
|
@@ -155,6 +156,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
155
156
|
self.teleport_interval = teleport_interval
|
|
156
157
|
self.dynamic_biomes = dynamic_biomes
|
|
157
158
|
self.biome_consumption_threshold = biome_consumption_threshold
|
|
159
|
+
if max_expiries_per_step < 1:
|
|
160
|
+
raise ValueError("max_expiries_per_step must be at least 1")
|
|
161
|
+
self.max_expiries_per_step = max_expiries_per_step
|
|
158
162
|
|
|
159
163
|
objects = (EMPTY,) + objects
|
|
160
164
|
if self.nowrap and not self.full_world:
|
|
@@ -608,63 +612,63 @@ class ForagaxEnv(environment.Environment):
|
|
|
608
612
|
& (current_objects_for_expiry > 0)
|
|
609
613
|
)
|
|
610
614
|
|
|
611
|
-
# Count how many objects actually need to expire
|
|
612
|
-
num_expiring = jnp.sum(should_expire)
|
|
613
|
-
|
|
614
615
|
# Only process expiry if there are actually objects to expire
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
616
|
+
has_expiring = jnp.any(should_expire)
|
|
617
|
+
|
|
618
|
+
# Precompute the first expiring index in flat space so the work inside cond is minimal.
|
|
619
|
+
overage = jnp.where(
|
|
620
|
+
should_expire,
|
|
621
|
+
object_ages - expiry_times,
|
|
622
|
+
-jnp.inf,
|
|
623
|
+
).reshape(-1)
|
|
624
|
+
sorted_flat_indices = jnp.argsort(overage)[::-1]
|
|
625
|
+
selected_flat_indices = jnp.where(
|
|
626
|
+
overage[sorted_flat_indices] > -jnp.inf,
|
|
627
|
+
sorted_flat_indices,
|
|
628
|
+
-jnp.ones_like(sorted_flat_indices),
|
|
629
|
+
)[: self.max_expiries_per_step]
|
|
622
630
|
|
|
631
|
+
def process_expiries():
|
|
623
632
|
key_local, expiry_key = jax.random.split(key)
|
|
624
633
|
|
|
625
|
-
def
|
|
626
|
-
|
|
627
|
-
y = y_indices[i]
|
|
628
|
-
x = x_indices[i]
|
|
634
|
+
def body_fn(i, obj_state):
|
|
635
|
+
flat_idx = selected_flat_indices[i]
|
|
629
636
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
def expire_one():
|
|
637
|
+
def expire_at(obj_state):
|
|
638
|
+
y = flat_idx // self.size[0]
|
|
639
|
+
x = flat_idx % self.size[0]
|
|
634
640
|
obj_id = current_objects_for_expiry[y, x]
|
|
635
|
-
exp_key = jax.random.fold_in(expiry_key,
|
|
641
|
+
exp_key = jax.random.fold_in(expiry_key, flat_idx)
|
|
636
642
|
exp_delay = jax.lax.switch(
|
|
637
643
|
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
638
644
|
)
|
|
639
645
|
timer_countdown = jax.lax.cond(
|
|
640
646
|
exp_delay == jnp.iinfo(jnp.int32).max,
|
|
641
|
-
lambda: 0,
|
|
642
|
-
lambda: exp_delay + 1,
|
|
647
|
+
lambda: 0,
|
|
648
|
+
lambda: exp_delay + 1,
|
|
643
649
|
)
|
|
644
650
|
|
|
645
|
-
|
|
646
|
-
rand_key = jax.random.
|
|
647
|
-
|
|
651
|
+
respawn_random = self.object_random_respawn[obj_id]
|
|
652
|
+
rand_key = jax.random.fold_in(exp_key, 1)
|
|
653
|
+
return self._place_timer(
|
|
648
654
|
obj_state,
|
|
649
655
|
y,
|
|
650
656
|
x,
|
|
651
657
|
obj_id,
|
|
652
658
|
timer_countdown,
|
|
653
|
-
|
|
659
|
+
respawn_random,
|
|
654
660
|
rand_key,
|
|
655
661
|
)
|
|
656
662
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
+
return jax.lax.cond(
|
|
664
|
+
flat_idx >= 0,
|
|
665
|
+
expire_at,
|
|
666
|
+
lambda obj_state: obj_state,
|
|
667
|
+
obj_state,
|
|
668
|
+
)
|
|
663
669
|
|
|
664
|
-
new_object_state
|
|
665
|
-
|
|
666
|
-
object_state,
|
|
667
|
-
jnp.arange(max_objects),
|
|
670
|
+
new_object_state = jax.lax.fori_loop(
|
|
671
|
+
0, self.max_expiries_per_step, body_fn, object_state
|
|
668
672
|
)
|
|
669
673
|
return key_local, new_object_state
|
|
670
674
|
|
|
@@ -672,7 +676,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
672
676
|
return key, object_state
|
|
673
677
|
|
|
674
678
|
key, object_state = jax.lax.cond(
|
|
675
|
-
|
|
679
|
+
has_expiring,
|
|
676
680
|
process_expiries,
|
|
677
681
|
no_expiries,
|
|
678
682
|
)
|
foragax/objects.py
CHANGED
|
@@ -241,6 +241,7 @@ class FourierObject(BaseForagaxObject):
|
|
|
241
241
|
reward_delay: int = 0,
|
|
242
242
|
max_reward_delay: Optional[int] = None,
|
|
243
243
|
regen_delay: Optional[Tuple[int, int]] = None,
|
|
244
|
+
reward_repeat: int = 1,
|
|
244
245
|
):
|
|
245
246
|
if max_reward_delay is None:
|
|
246
247
|
max_reward_delay = reward_delay
|
|
@@ -257,6 +258,7 @@ class FourierObject(BaseForagaxObject):
|
|
|
257
258
|
self.base_magnitude = base_magnitude
|
|
258
259
|
self.reward_delay_val = reward_delay
|
|
259
260
|
self.regen_delay_range = regen_delay
|
|
261
|
+
self.reward_repeat = reward_repeat
|
|
260
262
|
|
|
261
263
|
def get_state(self, key: jax.Array) -> jax.Array:
|
|
262
264
|
"""Generate random Fourier series parameters.
|
|
@@ -321,7 +323,7 @@ class FourierObject(BaseForagaxObject):
|
|
|
321
323
|
y_max = params[2]
|
|
322
324
|
|
|
323
325
|
# Normalize time to [0, 2π] using the object's period
|
|
324
|
-
t = 2.0 * jnp.pi * (clock % period) / period
|
|
326
|
+
t = 2.0 * jnp.pi * ((clock // self.reward_repeat) % period) / period
|
|
325
327
|
|
|
326
328
|
# Extract interleaved coefficients: [a1, b1, a2, b2, ...]
|
|
327
329
|
ab_coeffs = params[3:]
|
|
@@ -715,6 +717,7 @@ def create_fourier_objects(
|
|
|
715
717
|
base_magnitude: float = 1.0,
|
|
716
718
|
reward_delay: int = 0,
|
|
717
719
|
regen_delay: Optional[Tuple[int, int]] = None,
|
|
720
|
+
reward_repeat: int = 1,
|
|
718
721
|
):
|
|
719
722
|
"""Create HOT and COLD FourierObject instances.
|
|
720
723
|
|
|
@@ -733,6 +736,7 @@ def create_fourier_objects(
|
|
|
733
736
|
color=(0, 0, 0),
|
|
734
737
|
reward_delay=reward_delay,
|
|
735
738
|
regen_delay=regen_delay,
|
|
739
|
+
reward_repeat=reward_repeat,
|
|
736
740
|
)
|
|
737
741
|
|
|
738
742
|
cold = FourierObject(
|
|
@@ -742,6 +746,7 @@ def create_fourier_objects(
|
|
|
742
746
|
color=(0, 0, 0),
|
|
743
747
|
reward_delay=reward_delay,
|
|
744
748
|
regen_delay=regen_delay,
|
|
749
|
+
reward_repeat=reward_repeat,
|
|
745
750
|
)
|
|
746
751
|
|
|
747
752
|
return hot, cold
|
foragax/registry.py
CHANGED
|
@@ -128,6 +128,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
128
128
|
"deterministic_spawn": True,
|
|
129
129
|
"dynamic_biomes": True,
|
|
130
130
|
},
|
|
131
|
+
"ForagaxDiwali-v4": {
|
|
132
|
+
"size": (15, 15),
|
|
133
|
+
"aperture_size": None,
|
|
134
|
+
"objects": None,
|
|
135
|
+
"biomes": (
|
|
136
|
+
# Hot biome
|
|
137
|
+
Biome(start=(0, 2), stop=(15, 6), object_frequencies=(0.5, 0.0)),
|
|
138
|
+
# Cold biome
|
|
139
|
+
Biome(start=(0, 9), stop=(15, 13), object_frequencies=(0.0, 0.5)),
|
|
140
|
+
),
|
|
141
|
+
"nowrap": False,
|
|
142
|
+
"deterministic_spawn": True,
|
|
143
|
+
"dynamic_biomes": True,
|
|
144
|
+
"biome_consumption_threshold": 1000,
|
|
145
|
+
},
|
|
131
146
|
"ForagaxTwoBiome-v1": {
|
|
132
147
|
"size": (15, 15),
|
|
133
148
|
"aperture_size": None,
|
|
@@ -575,6 +590,13 @@ def make(
|
|
|
575
590
|
reward_delay=reward_delay,
|
|
576
591
|
regen_delay=(9, 11),
|
|
577
592
|
)[:1]
|
|
593
|
+
if env_id == "ForagaxDiwali-v4":
|
|
594
|
+
config["objects"] = create_fourier_objects(
|
|
595
|
+
num_fourier_terms=10,
|
|
596
|
+
reward_delay=reward_delay,
|
|
597
|
+
regen_delay=(9, 11),
|
|
598
|
+
reward_repeat=100,
|
|
599
|
+
)
|
|
578
600
|
|
|
579
601
|
if env_id == "ForagaxSineTwoBiome-v1":
|
|
580
602
|
biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
|
|
File without changes
|
|
File without changes
|
|
File without changes
|