continual-foragax 0.35.0__py3-none-any.whl → 0.36.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.35.0.dist-info → continual_foragax-0.36.1.dist-info}/METADATA +1 -1
- {continual_foragax-0.35.0.dist-info → continual_foragax-0.36.1.dist-info}/RECORD +7 -7
- foragax/env.py +40 -36
- foragax/registry.py +15 -0
- {continual_foragax-0.35.0.dist-info → continual_foragax-0.36.1.dist-info}/WHEEL +0 -0
- {continual_foragax-0.35.0.dist-info → continual_foragax-0.36.1.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.35.0.dist-info → continual_foragax-0.36.1.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=7IhLCEosM_IH19sMzRnGMvTNq9tlhVg4PU52hGz9XLE,66531
|
|
4
4
|
foragax/objects.py,sha256=PPuLYjD7em7GL404eSpP6q8TxF8p7JtQ1kIwh7uD_tU,26860
|
|
5
|
-
foragax/registry.py,sha256=
|
|
5
|
+
foragax/registry.py,sha256=y_MEvM0K_E4fzdp4XokzLh1D94UAuCu16wGwJV4wa_A,19289
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.
|
|
133
|
-
continual_foragax-0.
|
|
134
|
-
continual_foragax-0.
|
|
135
|
-
continual_foragax-0.
|
|
131
|
+
continual_foragax-0.36.1.dist-info/METADATA,sha256=SgkncBLbu9_MerET0r6Zto2YSLqfvmF4VE9vYXXh87g,4713
|
|
132
|
+
continual_foragax-0.36.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.36.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.36.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.36.1.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -132,6 +132,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
132
132
|
observation_type: str = "object",
|
|
133
133
|
dynamic_biomes: bool = False,
|
|
134
134
|
biome_consumption_threshold: float = 0.9,
|
|
135
|
+
max_expiries_per_step: int = 1,
|
|
135
136
|
):
|
|
136
137
|
super().__init__()
|
|
137
138
|
self._name = name
|
|
@@ -155,6 +156,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
155
156
|
self.teleport_interval = teleport_interval
|
|
156
157
|
self.dynamic_biomes = dynamic_biomes
|
|
157
158
|
self.biome_consumption_threshold = biome_consumption_threshold
|
|
159
|
+
if max_expiries_per_step < 1:
|
|
160
|
+
raise ValueError("max_expiries_per_step must be at least 1")
|
|
161
|
+
self.max_expiries_per_step = max_expiries_per_step
|
|
158
162
|
|
|
159
163
|
objects = (EMPTY,) + objects
|
|
160
164
|
if self.nowrap and not self.full_world:
|
|
@@ -608,63 +612,63 @@ class ForagaxEnv(environment.Environment):
|
|
|
608
612
|
& (current_objects_for_expiry > 0)
|
|
609
613
|
)
|
|
610
614
|
|
|
611
|
-
# Count how many objects actually need to expire
|
|
612
|
-
num_expiring = jnp.sum(should_expire)
|
|
613
|
-
|
|
614
615
|
# Only process expiry if there are actually objects to expire
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
616
|
+
has_expiring = jnp.any(should_expire)
|
|
617
|
+
|
|
618
|
+
# Precompute the first expiring index in flat space so the work inside cond is minimal.
|
|
619
|
+
overage = jnp.where(
|
|
620
|
+
should_expire,
|
|
621
|
+
object_ages - expiry_times,
|
|
622
|
+
-jnp.inf,
|
|
623
|
+
).reshape(-1)
|
|
624
|
+
sorted_flat_indices = jnp.argsort(overage)[::-1]
|
|
625
|
+
selected_flat_indices = jnp.where(
|
|
626
|
+
overage[sorted_flat_indices] > -jnp.inf,
|
|
627
|
+
sorted_flat_indices,
|
|
628
|
+
-jnp.ones_like(sorted_flat_indices),
|
|
629
|
+
)[: self.max_expiries_per_step]
|
|
622
630
|
|
|
631
|
+
def process_expiries():
|
|
623
632
|
key_local, expiry_key = jax.random.split(key)
|
|
624
633
|
|
|
625
|
-
def
|
|
626
|
-
|
|
627
|
-
y = y_indices[i]
|
|
628
|
-
x = x_indices[i]
|
|
634
|
+
def body_fn(i, obj_state):
|
|
635
|
+
flat_idx = selected_flat_indices[i]
|
|
629
636
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
def expire_one():
|
|
637
|
+
def expire_at(obj_state):
|
|
638
|
+
y = flat_idx // self.size[0]
|
|
639
|
+
x = flat_idx % self.size[0]
|
|
634
640
|
obj_id = current_objects_for_expiry[y, x]
|
|
635
|
-
exp_key = jax.random.fold_in(expiry_key,
|
|
641
|
+
exp_key = jax.random.fold_in(expiry_key, flat_idx)
|
|
636
642
|
exp_delay = jax.lax.switch(
|
|
637
643
|
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
638
644
|
)
|
|
639
645
|
timer_countdown = jax.lax.cond(
|
|
640
646
|
exp_delay == jnp.iinfo(jnp.int32).max,
|
|
641
|
-
lambda: 0,
|
|
642
|
-
lambda: exp_delay + 1,
|
|
647
|
+
lambda: 0,
|
|
648
|
+
lambda: exp_delay + 1,
|
|
643
649
|
)
|
|
644
650
|
|
|
645
|
-
|
|
646
|
-
rand_key = jax.random.
|
|
647
|
-
|
|
651
|
+
respawn_random = self.object_random_respawn[obj_id]
|
|
652
|
+
rand_key = jax.random.fold_in(exp_key, 1)
|
|
653
|
+
return self._place_timer(
|
|
648
654
|
obj_state,
|
|
649
655
|
y,
|
|
650
656
|
x,
|
|
651
657
|
obj_id,
|
|
652
658
|
timer_countdown,
|
|
653
|
-
|
|
659
|
+
respawn_random,
|
|
654
660
|
rand_key,
|
|
655
661
|
)
|
|
656
662
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
+
return jax.lax.cond(
|
|
664
|
+
flat_idx >= 0,
|
|
665
|
+
expire_at,
|
|
666
|
+
lambda obj_state: obj_state,
|
|
667
|
+
obj_state,
|
|
668
|
+
)
|
|
663
669
|
|
|
664
|
-
new_object_state
|
|
665
|
-
|
|
666
|
-
object_state,
|
|
667
|
-
jnp.arange(max_objects),
|
|
670
|
+
new_object_state = jax.lax.fori_loop(
|
|
671
|
+
0, self.max_expiries_per_step, body_fn, object_state
|
|
668
672
|
)
|
|
669
673
|
return key_local, new_object_state
|
|
670
674
|
|
|
@@ -672,7 +676,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
672
676
|
return key, object_state
|
|
673
677
|
|
|
674
678
|
key, object_state = jax.lax.cond(
|
|
675
|
-
|
|
679
|
+
has_expiring,
|
|
676
680
|
process_expiries,
|
|
677
681
|
no_expiries,
|
|
678
682
|
)
|
foragax/registry.py
CHANGED
|
@@ -119,6 +119,15 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
119
119
|
"dynamic_biomes": True,
|
|
120
120
|
"biome_consumption_threshold": 200,
|
|
121
121
|
},
|
|
122
|
+
"ForagaxDiwali-v3": {
|
|
123
|
+
"size": (15, 15),
|
|
124
|
+
"aperture_size": None,
|
|
125
|
+
"objects": None,
|
|
126
|
+
"biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.5,)),),
|
|
127
|
+
"nowrap": False,
|
|
128
|
+
"deterministic_spawn": True,
|
|
129
|
+
"dynamic_biomes": True,
|
|
130
|
+
},
|
|
122
131
|
"ForagaxTwoBiome-v1": {
|
|
123
132
|
"size": (15, 15),
|
|
124
133
|
"aperture_size": None,
|
|
@@ -560,6 +569,12 @@ def make(
|
|
|
560
569
|
reward_delay=reward_delay,
|
|
561
570
|
regen_delay=(9, 11),
|
|
562
571
|
)
|
|
572
|
+
if env_id == "ForagaxDiwali-v3":
|
|
573
|
+
config["objects"] = create_fourier_objects(
|
|
574
|
+
num_fourier_terms=10,
|
|
575
|
+
reward_delay=reward_delay,
|
|
576
|
+
regen_delay=(9, 11),
|
|
577
|
+
)[:1]
|
|
563
578
|
|
|
564
579
|
if env_id == "ForagaxSineTwoBiome-v1":
|
|
565
580
|
biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
|
|
File without changes
|
|
File without changes
|
|
File without changes
|