continual-foragax 0.38.0__py3-none-any.whl → 0.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.38.0.dist-info → continual_foragax-0.40.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.38.0.dist-info → continual_foragax-0.40.0.dist-info}/RECORD +7 -7
- foragax/env.py +46 -39
- foragax/registry.py +2 -5
- {continual_foragax-0.38.0.dist-info → continual_foragax-0.40.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.38.0.dist-info → continual_foragax-0.40.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.38.0.dist-info → continual_foragax-0.40.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256
|
|
3
|
+
foragax/env.py,sha256=-lBBh70JRInOFR6mmU8uyA9nzgIIPapIdY20y4-vwQU,67848
|
|
4
4
|
foragax/objects.py,sha256=aVc7lD3CTyRP9wm_Vs93qo4l_B1kbiYGKPtkd_SVXjs,27061
|
|
5
|
-
foragax/registry.py,sha256=
|
|
5
|
+
foragax/registry.py,sha256=2qi7Dq96RZVyBQwAnejdBztQWJOQHNAw26UUPBTzIEY,20551
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
7
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
|
8
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.
|
|
133
|
-
continual_foragax-0.
|
|
134
|
-
continual_foragax-0.
|
|
135
|
-
continual_foragax-0.
|
|
131
|
+
continual_foragax-0.40.0.dist-info/METADATA,sha256=DN8INomefM2v426Q75KtXL-VqD_fsFKAjQMaphc9Rl4,4713
|
|
132
|
+
continual_foragax-0.40.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.40.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.40.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.40.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -6,7 +6,7 @@ Source: https://github.com/andnp/Forager
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
from enum import IntEnum
|
|
8
8
|
from functools import partial
|
|
9
|
-
from typing import Any, Dict,
|
|
9
|
+
from typing import Any, Dict, Tuple, Union
|
|
10
10
|
|
|
11
11
|
import jax
|
|
12
12
|
import jax.numpy as jnp
|
|
@@ -128,7 +128,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
128
128
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
|
129
129
|
nowrap: bool = False,
|
|
130
130
|
deterministic_spawn: bool = False,
|
|
131
|
-
teleport_interval: Optional[int] = None,
|
|
132
131
|
observation_type: str = "object",
|
|
133
132
|
dynamic_biomes: bool = False,
|
|
134
133
|
biome_consumption_threshold: float = 0.9,
|
|
@@ -154,7 +153,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
154
153
|
self.observation_type = observation_type
|
|
155
154
|
self.nowrap = nowrap
|
|
156
155
|
self.deterministic_spawn = deterministic_spawn
|
|
157
|
-
self.teleport_interval = teleport_interval
|
|
158
156
|
self.dynamic_biomes = dynamic_biomes
|
|
159
157
|
self.biome_consumption_threshold = biome_consumption_threshold
|
|
160
158
|
self.dynamic_biome_spawn_empty = dynamic_biome_spawn_empty
|
|
@@ -391,23 +389,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
391
389
|
is_blocking = self.object_blocking[obj_at_new_pos]
|
|
392
390
|
pos = jax.lax.select(is_blocking, state.pos, new_pos)
|
|
393
391
|
|
|
394
|
-
# Check for automatic teleport
|
|
395
|
-
if self.teleport_interval is not None:
|
|
396
|
-
should_teleport = jnp.mod(state.time + 1, self.teleport_interval) == 0
|
|
397
|
-
else:
|
|
398
|
-
should_teleport = False
|
|
399
|
-
|
|
400
|
-
def teleport_fn():
|
|
401
|
-
# Calculate squared distances from current position to each biome center
|
|
402
|
-
diffs = self.biome_centers_jax - pos
|
|
403
|
-
distances = jnp.sum(diffs**2, axis=1)
|
|
404
|
-
# Find the index of the furthest biome center
|
|
405
|
-
furthest_idx = jnp.argmax(distances)
|
|
406
|
-
new_pos = self.biome_centers_jax[furthest_idx]
|
|
407
|
-
return new_pos
|
|
408
|
-
|
|
409
|
-
pos = jax.lax.cond(should_teleport, teleport_fn, lambda: pos)
|
|
410
|
-
|
|
411
392
|
# 2. HANDLE COLLISIONS AND REWARDS
|
|
412
393
|
obj_at_pos = current_objects[pos[1], pos[0]]
|
|
413
394
|
is_collectable = self.object_collectable[obj_at_pos]
|
|
@@ -774,20 +755,53 @@ class ForagaxEnv(environment.Environment):
|
|
|
774
755
|
biome_mask = self.biome_masks_array[i]
|
|
775
756
|
new_gen_value = new_biome_state.generation[i]
|
|
776
757
|
|
|
777
|
-
#
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
758
|
+
# Update mask: biome area AND needs respawn
|
|
759
|
+
should_update = biome_mask & should_respawn[i][..., None]
|
|
760
|
+
|
|
761
|
+
# 1. Merge: Overwrite with new objects if present, otherwise keep existing
|
|
762
|
+
# If the new spawn has an object, we take it. If it's empty, we keep whatever was there.
|
|
763
|
+
new_spawn_valid = all_new_objects[i] > 0
|
|
781
764
|
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
765
|
+
merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
|
|
766
|
+
merged_colors = jnp.where(
|
|
767
|
+
new_spawn_valid[..., None], all_new_colors[i], new_color
|
|
785
768
|
)
|
|
786
|
-
|
|
787
|
-
|
|
769
|
+
merged_params = jnp.where(
|
|
770
|
+
new_spawn_valid[..., None], all_new_params[i], new_params
|
|
788
771
|
)
|
|
789
|
-
|
|
790
|
-
|
|
772
|
+
|
|
773
|
+
# For generation/spawn time, update only if we took the NEW object
|
|
774
|
+
merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
|
|
775
|
+
merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
|
|
776
|
+
|
|
777
|
+
# 2. Apply dropout to the MERGED result (only where we are allowed to update)
|
|
778
|
+
if self.dynamic_biome_spawn_empty > 0:
|
|
779
|
+
key, dropout_key = jax.random.split(key)
|
|
780
|
+
# Dropout applies only to cells that are both in the biome AND need updating
|
|
781
|
+
keep_mask = jax.random.bernoulli(
|
|
782
|
+
dropout_key, 1.0 - self.dynamic_biome_spawn_empty, merged_objs.shape
|
|
783
|
+
)
|
|
784
|
+
# Only apply dropout where should_update is true; elsewhere, keep merged_objs
|
|
785
|
+
dropout_mask = should_update & keep_mask
|
|
786
|
+
# Apply dropout only to the merged result and associated metadata
|
|
787
|
+
final_objs = jnp.where(dropout_mask, merged_objs, 0)
|
|
788
|
+
final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
|
|
789
|
+
final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
|
|
790
|
+
final_gen = jnp.where(dropout_mask, merged_gen, 0)
|
|
791
|
+
final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
|
|
792
|
+
else:
|
|
793
|
+
final_objs = merged_objs
|
|
794
|
+
final_colors = merged_colors
|
|
795
|
+
final_params = merged_params
|
|
796
|
+
final_gen = merged_gen
|
|
797
|
+
final_spawn = merged_spawn
|
|
798
|
+
|
|
799
|
+
# 3. Write back: Only update where should_update is true
|
|
800
|
+
new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
|
|
801
|
+
new_color = jnp.where(should_update[..., None], final_colors, new_color)
|
|
802
|
+
new_params = jnp.where(should_update[..., None], final_params, new_params)
|
|
803
|
+
new_gen = jnp.where(should_update, final_gen, new_gen)
|
|
804
|
+
new_spawn = jnp.where(should_update, final_spawn, new_spawn)
|
|
791
805
|
|
|
792
806
|
# Clear timers in respawning biomes
|
|
793
807
|
new_respawn_timer = object_state.respawn_timer
|
|
@@ -890,7 +904,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
890
904
|
biome_freqs = self.biome_object_frequencies[biome_idx]
|
|
891
905
|
biome_mask = self.biome_masks_array[biome_idx]
|
|
892
906
|
|
|
893
|
-
key, spawn_key, color_key, params_key
|
|
907
|
+
key, spawn_key, color_key, params_key = jax.random.split(key, 4)
|
|
894
908
|
|
|
895
909
|
# Generate object IDs using deterministic or random spawn
|
|
896
910
|
if deterministic:
|
|
@@ -929,13 +943,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
929
943
|
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
|
930
944
|
)
|
|
931
945
|
|
|
932
|
-
# Apply random dropout if configured
|
|
933
|
-
if self.dynamic_biome_spawn_empty > 0:
|
|
934
|
-
dropout_mask = jax.random.bernoulli(
|
|
935
|
-
dropout_key, 1.0 - self.dynamic_biome_spawn_empty, object_grid.shape
|
|
936
|
-
)
|
|
937
|
-
object_grid = jnp.where(dropout_mask, object_grid, 0)
|
|
938
|
-
|
|
939
946
|
# Initialize color grid
|
|
940
947
|
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
|
|
941
948
|
|
foragax/registry.py
CHANGED
|
@@ -147,12 +147,12 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
147
147
|
"size": (15, 15),
|
|
148
148
|
"aperture_size": None,
|
|
149
149
|
"objects": None,
|
|
150
|
-
"biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.
|
|
150
|
+
"biomes": (Biome(start=(0, 0), stop=(15, 15), object_frequencies=(0.4,)),),
|
|
151
151
|
"nowrap": False,
|
|
152
152
|
"deterministic_spawn": True,
|
|
153
153
|
"dynamic_biomes": True,
|
|
154
154
|
"biome_consumption_threshold": 1000,
|
|
155
|
-
"dynamic_biome_spawn_empty": 0.
|
|
155
|
+
"dynamic_biome_spawn_empty": 0.4,
|
|
156
156
|
},
|
|
157
157
|
"ForagaxTwoBiome-v1": {
|
|
158
158
|
"size": (15, 15),
|
|
@@ -636,9 +636,6 @@ def make(
|
|
|
636
636
|
biome2_deathcap,
|
|
637
637
|
)
|
|
638
638
|
|
|
639
|
-
if env_id == "ForagaxTwoBiome-v16":
|
|
640
|
-
config["teleport_interval"] = 10000
|
|
641
|
-
|
|
642
639
|
# Backward compatibility: map "world" to "object" with full world
|
|
643
640
|
if observation_type == "world":
|
|
644
641
|
# add deprecation warning
|
|
File without changes
|
|
File without changes
|
|
File without changes
|