continual-foragax 0.32.0__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.33.0.dist-info}/METADATA +1 -6
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.33.0.dist-info}/RECORD +6 -6
- foragax/env.py +62 -31
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.33.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.33.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.33.0.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: continual-foragax
|
|
3
|
-
Version: 0.32.0
|
|
3
|
+
Version: 0.33.0
|
|
4
4
|
Summary: A continual reinforcement learning benchmark
|
|
5
5
|
Author-email: Steven Tang <stang5@ualberta.ca>
|
|
6
6
|
Requires-Python: >=3.8
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
Requires-Dist: gymnax
|
|
9
9
|
Requires-Dist: six; python_version < "3.10"
|
|
10
|
-
Provides-Extra: dev
|
|
11
|
-
Requires-Dist: pre-commit; extra == "dev"
|
|
12
|
-
Requires-Dist: pytest; extra == "dev"
|
|
13
|
-
Requires-Dist: pytest-benchmark; extra == "dev"
|
|
14
|
-
Requires-Dist: ruff; extra == "dev"
|
|
15
10
|
|
|
16
11
|
# foragax
|
|
17
12
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=kf0B0D5CMUnOMiZ8diBlmlt_vf2Wh-M4gSNoeo4jfHY,55147
|
|
4
4
|
foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
|
|
5
5
|
foragax/registry.py,sha256=hfzQHNgX6uoOdbf4_21iH25abQVQZIjBWn7h5bdrSBg,17981
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.32.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
-
continual_foragax-0.32.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
-
continual_foragax-0.32.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
-
continual_foragax-0.32.0.dist-info/RECORD,,
|
|
131
|
+
continual_foragax-0.33.0.dist-info/METADATA,sha256=vEOZLNVNPhccIZDIrN-puYjVPQKWHtThbhoVdUjhF4A,4713
|
|
132
|
+
continual_foragax-0.33.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.33.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.33.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.33.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -506,17 +506,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
506
506
|
lambda: object_state,
|
|
507
507
|
)
|
|
508
508
|
|
|
509
|
-
# Clear color grid when object is collected
|
|
510
|
-
object_state = jax.lax.cond(
|
|
511
|
-
should_collect_now,
|
|
512
|
-
lambda: object_state.replace(
|
|
513
|
-
color=object_state.color.at[pos[1], pos[0]].set(
|
|
514
|
-
jnp.full((3,), 255, dtype=jnp.uint8)
|
|
515
|
-
)
|
|
516
|
-
),
|
|
517
|
-
lambda: object_state,
|
|
518
|
-
)
|
|
519
|
-
|
|
520
509
|
# 3.5. HANDLE OBJECT EXPIRY
|
|
521
510
|
# Only process expiry if there are objects that can expire
|
|
522
511
|
key, object_state = self.expire_objects(key, state, object_state)
|
|
@@ -564,6 +553,23 @@ class ForagaxEnv(environment.Environment):
|
|
|
564
553
|
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
565
554
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
|
566
555
|
|
|
556
|
+
# Compute reward at each grid position
|
|
557
|
+
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
558
|
+
|
|
559
|
+
def compute_reward(obj_id, params):
|
|
560
|
+
return jax.lax.cond(
|
|
561
|
+
obj_id > 0,
|
|
562
|
+
lambda: jax.lax.switch(
|
|
563
|
+
obj_id, self.reward_fns, state.time, fixed_key, params
|
|
564
|
+
),
|
|
565
|
+
lambda: 0.0,
|
|
566
|
+
)
|
|
567
|
+
|
|
568
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
569
|
+
object_state.object_id, object_state.state_params
|
|
570
|
+
)
|
|
571
|
+
info["rewards"] = reward_grid
|
|
572
|
+
|
|
567
573
|
# 4. UPDATE STATE
|
|
568
574
|
state = EnvState(
|
|
569
575
|
pos=pos,
|
|
@@ -648,12 +654,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
648
654
|
rand_key,
|
|
649
655
|
)
|
|
650
656
|
|
|
651
|
-
# Clear color grid when object expires
|
|
652
|
-
empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
|
|
653
|
-
new_obj_state = new_obj_state.replace(
|
|
654
|
-
color=new_obj_state.color.at[y, x].set(empty_color)
|
|
655
|
-
)
|
|
656
|
-
|
|
657
657
|
return new_obj_state
|
|
658
658
|
|
|
659
659
|
def no_op():
|
|
@@ -1114,8 +1114,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
1114
1114
|
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1115
1115
|
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1116
1116
|
out_of_bounds = y_out | x_out
|
|
1117
|
-
|
|
1118
|
-
|
|
1117
|
+
|
|
1118
|
+
# Handle both object_id grids (2D) and color grids (3D)
|
|
1119
|
+
if len(values.shape) == 3:
|
|
1120
|
+
# Color grid: use PADDING color (0, 0, 0)
|
|
1121
|
+
padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
|
|
1122
|
+
aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
|
|
1123
|
+
else:
|
|
1124
|
+
# Object ID grid: use PADDING index
|
|
1125
|
+
padding_index = self.object_ids[-1]
|
|
1126
|
+
aperture = jnp.where(out_of_bounds, padding_index, values)
|
|
1119
1127
|
else:
|
|
1120
1128
|
aperture = values
|
|
1121
1129
|
|
|
@@ -1124,12 +1132,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1124
1132
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
|
1125
1133
|
"""Get observation based on observation_type and full_world."""
|
|
1126
1134
|
obs_grid = state.object_state.object_id
|
|
1135
|
+
color_grid = state.object_state.color
|
|
1127
1136
|
|
|
1128
1137
|
if self.full_world:
|
|
1129
1138
|
return self._get_world_obs(obs_grid, state)
|
|
1130
1139
|
else:
|
|
1131
1140
|
grid = self._get_aperture(obs_grid, state.pos)
|
|
1132
|
-
|
|
1141
|
+
color_grid = self._get_aperture(color_grid, state.pos)
|
|
1142
|
+
return self._get_aperture_obs(grid, color_grid, state)
|
|
1133
1143
|
|
|
1134
1144
|
def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
|
|
1135
1145
|
"""Get world observation."""
|
|
@@ -1146,12 +1156,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1146
1156
|
obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
|
|
1147
1157
|
return obs
|
|
1148
1158
|
elif self.observation_type == "rgb":
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1159
|
+
# Use state colors directly (supports dynamic biomes)
|
|
1160
|
+
colors = state.object_state.color / 255.0
|
|
1161
|
+
|
|
1162
|
+
# Mask empty cells (object_id == 0) to white
|
|
1163
|
+
empty_mask = obs_grid == 0
|
|
1164
|
+
white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
|
|
1165
|
+
obs = jnp.where(empty_mask[..., None], white_color, colors)
|
|
1166
|
+
|
|
1155
1167
|
return obs
|
|
1156
1168
|
elif self.observation_type == "color":
|
|
1157
1169
|
# Handle case with no objects (only EMPTY)
|
|
@@ -1168,17 +1180,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
1168
1180
|
else:
|
|
1169
1181
|
raise ValueError(f"Unknown observation_type: {self.observation_type}")
|
|
1170
1182
|
|
|
1171
|
-
def _get_aperture_obs(
|
|
1183
|
+
def _get_aperture_obs(
|
|
1184
|
+
self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
|
|
1185
|
+
) -> jax.Array:
|
|
1172
1186
|
"""Get aperture observation."""
|
|
1173
1187
|
if self.observation_type == "object":
|
|
1174
1188
|
num_obj_types = len(self.object_ids)
|
|
1175
1189
|
obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
|
|
1176
1190
|
return obs
|
|
1177
1191
|
elif self.observation_type == "rgb":
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1192
|
+
# Use the color aperture that was passed in
|
|
1193
|
+
aperture_colors = color_aperture / 255.0
|
|
1194
|
+
|
|
1195
|
+
# Mask empty cells (object_id == 0) to white
|
|
1196
|
+
empty_mask = aperture == 0
|
|
1197
|
+
white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
|
|
1198
|
+
|
|
1199
|
+
obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
|
|
1200
|
+
|
|
1182
1201
|
return obs
|
|
1183
1202
|
elif self.observation_type == "color":
|
|
1184
1203
|
# Handle case with no objects (only EMPTY)
|
|
@@ -1229,6 +1248,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
1229
1248
|
if self.dynamic_biomes:
|
|
1230
1249
|
# Use per-instance colors from state
|
|
1231
1250
|
img = state.object_state.color.copy()
|
|
1251
|
+
# Mask empty cells (object_id == 0) to white
|
|
1252
|
+
empty_mask = state.object_state.object_id == 0
|
|
1253
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1254
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1232
1255
|
else:
|
|
1233
1256
|
# Use default object colors
|
|
1234
1257
|
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
@@ -1297,6 +1320,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1297
1320
|
)
|
|
1298
1321
|
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1299
1322
|
|
|
1323
|
+
# Mask empty cells (object_id == 0) to white
|
|
1324
|
+
aperture_object_ids = state.object_state.object_id[
|
|
1325
|
+
y_coords_adj, x_coords_adj
|
|
1326
|
+
]
|
|
1327
|
+
empty_mask = aperture_object_ids == 0
|
|
1328
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1329
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1330
|
+
|
|
1300
1331
|
if self.nowrap:
|
|
1301
1332
|
# For out-of-bounds, use padding object color
|
|
1302
1333
|
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|