continual-foragax 0.32.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.32.1.dist-info}/METADATA +1 -6
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.32.1.dist-info}/RECORD +6 -6
- foragax/env.py +45 -31
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.32.1.dist-info}/WHEEL +0 -0
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.32.1.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.32.0.dist-info → continual_foragax-0.32.1.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: continual-foragax
|
|
3
|
-
Version: 0.32.0
|
|
3
|
+
Version: 0.32.1
|
|
4
4
|
Summary: A continual reinforcement learning benchmark
|
|
5
5
|
Author-email: Steven Tang <stang5@ualberta.ca>
|
|
6
6
|
Requires-Python: >=3.8
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
Requires-Dist: gymnax
|
|
9
9
|
Requires-Dist: six; python_version < "3.10"
|
|
10
|
-
Provides-Extra: dev
|
|
11
|
-
Requires-Dist: pre-commit; extra == "dev"
|
|
12
|
-
Requires-Dist: pytest; extra == "dev"
|
|
13
|
-
Requires-Dist: pytest-benchmark; extra == "dev"
|
|
14
|
-
Requires-Dist: ruff; extra == "dev"
|
|
15
10
|
|
|
16
11
|
# foragax
|
|
17
12
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=Jjo7XypfQf6ePoKYTV3xvolK4qpacuaifWLZB0ke5y8,54559
|
|
4
4
|
foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
|
|
5
5
|
foragax/registry.py,sha256=hfzQHNgX6uoOdbf4_21iH25abQVQZIjBWn7h5bdrSBg,17981
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.32.
|
|
132
|
-
continual_foragax-0.32.
|
|
133
|
-
continual_foragax-0.32.
|
|
134
|
-
continual_foragax-0.32.
|
|
135
|
-
continual_foragax-0.32.
|
|
131
|
+
continual_foragax-0.32.1.dist-info/METADATA,sha256=ZG39JPQKbUW7ag-vTZtcDfL8Wvt-nCfO-KOCOZMgOIo,4713
|
|
132
|
+
continual_foragax-0.32.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.32.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.32.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.32.1.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -506,17 +506,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
506
506
|
lambda: object_state,
|
|
507
507
|
)
|
|
508
508
|
|
|
509
|
-
# Clear color grid when object is collected
|
|
510
|
-
object_state = jax.lax.cond(
|
|
511
|
-
should_collect_now,
|
|
512
|
-
lambda: object_state.replace(
|
|
513
|
-
color=object_state.color.at[pos[1], pos[0]].set(
|
|
514
|
-
jnp.full((3,), 255, dtype=jnp.uint8)
|
|
515
|
-
)
|
|
516
|
-
),
|
|
517
|
-
lambda: object_state,
|
|
518
|
-
)
|
|
519
|
-
|
|
520
509
|
# 3.5. HANDLE OBJECT EXPIRY
|
|
521
510
|
# Only process expiry if there are objects that can expire
|
|
522
511
|
key, object_state = self.expire_objects(key, state, object_state)
|
|
@@ -648,12 +637,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
648
637
|
rand_key,
|
|
649
638
|
)
|
|
650
639
|
|
|
651
|
-
# Clear color grid when object expires
|
|
652
|
-
empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
|
|
653
|
-
new_obj_state = new_obj_state.replace(
|
|
654
|
-
color=new_obj_state.color.at[y, x].set(empty_color)
|
|
655
|
-
)
|
|
656
|
-
|
|
657
640
|
return new_obj_state
|
|
658
641
|
|
|
659
642
|
def no_op():
|
|
@@ -1114,8 +1097,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
1114
1097
|
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1115
1098
|
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1116
1099
|
out_of_bounds = y_out | x_out
|
|
1117
|
-
|
|
1118
|
-
|
|
1100
|
+
|
|
1101
|
+
# Handle both object_id grids (2D) and color grids (3D)
|
|
1102
|
+
if len(values.shape) == 3:
|
|
1103
|
+
# Color grid: use PADDING color (0, 0, 0)
|
|
1104
|
+
padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
|
|
1105
|
+
aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
|
|
1106
|
+
else:
|
|
1107
|
+
# Object ID grid: use PADDING index
|
|
1108
|
+
padding_index = self.object_ids[-1]
|
|
1109
|
+
aperture = jnp.where(out_of_bounds, padding_index, values)
|
|
1119
1110
|
else:
|
|
1120
1111
|
aperture = values
|
|
1121
1112
|
|
|
@@ -1124,12 +1115,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1124
1115
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
|
1125
1116
|
"""Get observation based on observation_type and full_world."""
|
|
1126
1117
|
obs_grid = state.object_state.object_id
|
|
1118
|
+
color_grid = state.object_state.color
|
|
1127
1119
|
|
|
1128
1120
|
if self.full_world:
|
|
1129
1121
|
return self._get_world_obs(obs_grid, state)
|
|
1130
1122
|
else:
|
|
1131
1123
|
grid = self._get_aperture(obs_grid, state.pos)
|
|
1132
|
-
|
|
1124
|
+
color_grid = self._get_aperture(color_grid, state.pos)
|
|
1125
|
+
return self._get_aperture_obs(grid, color_grid, state)
|
|
1133
1126
|
|
|
1134
1127
|
def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
|
|
1135
1128
|
"""Get world observation."""
|
|
@@ -1146,12 +1139,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1146
1139
|
obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
|
|
1147
1140
|
return obs
|
|
1148
1141
|
elif self.observation_type == "rgb":
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1142
|
+
# Use state colors directly (supports dynamic biomes)
|
|
1143
|
+
colors = state.object_state.color / 255.0
|
|
1144
|
+
|
|
1145
|
+
# Mask empty cells (object_id == 0) to white
|
|
1146
|
+
empty_mask = obs_grid == 0
|
|
1147
|
+
white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
|
|
1148
|
+
obs = jnp.where(empty_mask[..., None], white_color, colors)
|
|
1149
|
+
|
|
1155
1150
|
return obs
|
|
1156
1151
|
elif self.observation_type == "color":
|
|
1157
1152
|
# Handle case with no objects (only EMPTY)
|
|
@@ -1168,17 +1163,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
1168
1163
|
else:
|
|
1169
1164
|
raise ValueError(f"Unknown observation_type: {self.observation_type}")
|
|
1170
1165
|
|
|
1171
|
-
def _get_aperture_obs(
|
|
1166
|
+
def _get_aperture_obs(
|
|
1167
|
+
self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
|
|
1168
|
+
) -> jax.Array:
|
|
1172
1169
|
"""Get aperture observation."""
|
|
1173
1170
|
if self.observation_type == "object":
|
|
1174
1171
|
num_obj_types = len(self.object_ids)
|
|
1175
1172
|
obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
|
|
1176
1173
|
return obs
|
|
1177
1174
|
elif self.observation_type == "rgb":
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1175
|
+
# Use the color aperture that was passed in
|
|
1176
|
+
aperture_colors = color_aperture / 255.0
|
|
1177
|
+
|
|
1178
|
+
# Mask empty cells (object_id == 0) to white
|
|
1179
|
+
empty_mask = aperture == 0
|
|
1180
|
+
white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
|
|
1181
|
+
|
|
1182
|
+
obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
|
|
1183
|
+
|
|
1182
1184
|
return obs
|
|
1183
1185
|
elif self.observation_type == "color":
|
|
1184
1186
|
# Handle case with no objects (only EMPTY)
|
|
@@ -1229,6 +1231,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
1229
1231
|
if self.dynamic_biomes:
|
|
1230
1232
|
# Use per-instance colors from state
|
|
1231
1233
|
img = state.object_state.color.copy()
|
|
1234
|
+
# Mask empty cells (object_id == 0) to white
|
|
1235
|
+
empty_mask = state.object_state.object_id == 0
|
|
1236
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1237
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1232
1238
|
else:
|
|
1233
1239
|
# Use default object colors
|
|
1234
1240
|
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
@@ -1297,6 +1303,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
1297
1303
|
)
|
|
1298
1304
|
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1299
1305
|
|
|
1306
|
+
# Mask empty cells (object_id == 0) to white
|
|
1307
|
+
aperture_object_ids = state.object_state.object_id[
|
|
1308
|
+
y_coords_adj, x_coords_adj
|
|
1309
|
+
]
|
|
1310
|
+
empty_mask = aperture_object_ids == 0
|
|
1311
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1312
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1313
|
+
|
|
1300
1314
|
if self.nowrap:
|
|
1301
1315
|
# For out-of-bounds, use padding object color
|
|
1302
1316
|
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|