continual-foragax 0.32.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.32.0
3
+ Version: 0.32.1
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: gymnax
9
9
  Requires-Dist: six; python_version < "3.10"
10
- Provides-Extra: dev
11
- Requires-Dist: pre-commit; extra == "dev"
12
- Requires-Dist: pytest; extra == "dev"
13
- Requires-Dist: pytest-benchmark; extra == "dev"
14
- Requires-Dist: ruff; extra == "dev"
15
10
 
16
11
  # foragax
17
12
 
@@ -1,6 +1,6 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=OgZegbHnmVCchSEBHfwm4Kgc4rrwTEnx6r4WdAvK_P4,53800
3
+ foragax/env.py,sha256=Jjo7XypfQf6ePoKYTV3xvolK4qpacuaifWLZB0ke5y8,54559
4
4
  foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
5
5
  foragax/registry.py,sha256=hfzQHNgX6uoOdbf4_21iH25abQVQZIjBWn7h5bdrSBg,17981
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.32.0.dist-info/METADATA,sha256=2FLbgAsQJg-W3DbOvDap6PRb5_ku7g8mfuXFuUZ2Ybs,4897
132
- continual_foragax-0.32.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.32.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.32.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.32.0.dist-info/RECORD,,
131
+ continual_foragax-0.32.1.dist-info/METADATA,sha256=ZG39JPQKbUW7ag-vTZtcDfL8Wvt-nCfO-KOCOZMgOIo,4713
132
+ continual_foragax-0.32.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.32.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.32.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.32.1.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -506,17 +506,6 @@ class ForagaxEnv(environment.Environment):
506
506
  lambda: object_state,
507
507
  )
508
508
 
509
- # Clear color grid when object is collected
510
- object_state = jax.lax.cond(
511
- should_collect_now,
512
- lambda: object_state.replace(
513
- color=object_state.color.at[pos[1], pos[0]].set(
514
- jnp.full((3,), 255, dtype=jnp.uint8)
515
- )
516
- ),
517
- lambda: object_state,
518
- )
519
-
520
509
  # 3.5. HANDLE OBJECT EXPIRY
521
510
  # Only process expiry if there are objects that can expire
522
511
  key, object_state = self.expire_objects(key, state, object_state)
@@ -648,12 +637,6 @@ class ForagaxEnv(environment.Environment):
648
637
  rand_key,
649
638
  )
650
639
 
651
- # Clear color grid when object expires
652
- empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
653
- new_obj_state = new_obj_state.replace(
654
- color=new_obj_state.color.at[y, x].set(empty_color)
655
- )
656
-
657
640
  return new_obj_state
658
641
 
659
642
  def no_op():
@@ -1114,8 +1097,16 @@ class ForagaxEnv(environment.Environment):
1114
1097
  y_out = (y_coords < 0) | (y_coords >= self.size[1])
1115
1098
  x_out = (x_coords < 0) | (x_coords >= self.size[0])
1116
1099
  out_of_bounds = y_out | x_out
1117
- padding_index = self.object_ids[-1]
1118
- aperture = jnp.where(out_of_bounds, padding_index, values)
1100
+
1101
+ # Handle both object_id grids (2D) and color grids (3D)
1102
+ if len(values.shape) == 3:
1103
+ # Color grid: use PADDING color (0, 0, 0)
1104
+ padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
1105
+ aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
1106
+ else:
1107
+ # Object ID grid: use PADDING index
1108
+ padding_index = self.object_ids[-1]
1109
+ aperture = jnp.where(out_of_bounds, padding_index, values)
1119
1110
  else:
1120
1111
  aperture = values
1121
1112
 
@@ -1124,12 +1115,14 @@ class ForagaxEnv(environment.Environment):
1124
1115
  def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
1125
1116
  """Get observation based on observation_type and full_world."""
1126
1117
  obs_grid = state.object_state.object_id
1118
+ color_grid = state.object_state.color
1127
1119
 
1128
1120
  if self.full_world:
1129
1121
  return self._get_world_obs(obs_grid, state)
1130
1122
  else:
1131
1123
  grid = self._get_aperture(obs_grid, state.pos)
1132
- return self._get_aperture_obs(grid, state)
1124
+ color_grid = self._get_aperture(color_grid, state.pos)
1125
+ return self._get_aperture_obs(grid, color_grid, state)
1133
1126
 
1134
1127
  def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
1135
1128
  """Get world observation."""
@@ -1146,12 +1139,14 @@ class ForagaxEnv(environment.Environment):
1146
1139
  obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
1147
1140
  return obs
1148
1141
  elif self.observation_type == "rgb":
1149
- obs = jax.nn.one_hot(obs_grid, num_obj_types)
1150
- # Agent position
1151
- obs = obs.at[state.pos[1], state.pos[0], :].set(0)
1152
- obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
1153
- colors = self.object_colors / 255.0
1154
- obs = jnp.tensordot(obs, colors, axes=1)
1142
+ # Use state colors directly (supports dynamic biomes)
1143
+ colors = state.object_state.color / 255.0
1144
+
1145
+ # Mask empty cells (object_id == 0) to white
1146
+ empty_mask = obs_grid == 0
1147
+ white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
1148
+ obs = jnp.where(empty_mask[..., None], white_color, colors)
1149
+
1155
1150
  return obs
1156
1151
  elif self.observation_type == "color":
1157
1152
  # Handle case with no objects (only EMPTY)
@@ -1168,17 +1163,24 @@ class ForagaxEnv(environment.Environment):
1168
1163
  else:
1169
1164
  raise ValueError(f"Unknown observation_type: {self.observation_type}")
1170
1165
 
1171
- def _get_aperture_obs(self, aperture: jax.Array, state: EnvState) -> jax.Array:
1166
+ def _get_aperture_obs(
1167
+ self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
1168
+ ) -> jax.Array:
1172
1169
  """Get aperture observation."""
1173
1170
  if self.observation_type == "object":
1174
1171
  num_obj_types = len(self.object_ids)
1175
1172
  obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
1176
1173
  return obs
1177
1174
  elif self.observation_type == "rgb":
1178
- num_obj_types = len(self.object_ids)
1179
- aperture_one_hot = jax.nn.one_hot(aperture, num_obj_types)
1180
- colors = self.object_colors / 255.0
1181
- obs = jnp.tensordot(aperture_one_hot, colors, axes=1)
1175
+ # Use the color aperture that was passed in
1176
+ aperture_colors = color_aperture / 255.0
1177
+
1178
+ # Mask empty cells (object_id == 0) to white
1179
+ empty_mask = aperture == 0
1180
+ white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
1181
+
1182
+ obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
1183
+
1182
1184
  return obs
1183
1185
  elif self.observation_type == "color":
1184
1186
  # Handle case with no objects (only EMPTY)
@@ -1229,6 +1231,10 @@ class ForagaxEnv(environment.Environment):
1229
1231
  if self.dynamic_biomes:
1230
1232
  # Use per-instance colors from state
1231
1233
  img = state.object_state.color.copy()
1234
+ # Mask empty cells (object_id == 0) to white
1235
+ empty_mask = state.object_state.object_id == 0
1236
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1237
+ img = jnp.where(empty_mask[..., None], white_color, img)
1232
1238
  else:
1233
1239
  # Use default object colors
1234
1240
  img = jnp.zeros((self.size[1], self.size[0], 3))
@@ -1297,6 +1303,14 @@ class ForagaxEnv(environment.Environment):
1297
1303
  )
1298
1304
  img = state.object_state.color[y_coords_adj, x_coords_adj]
1299
1305
 
1306
+ # Mask empty cells (object_id == 0) to white
1307
+ aperture_object_ids = state.object_state.object_id[
1308
+ y_coords_adj, x_coords_adj
1309
+ ]
1310
+ empty_mask = aperture_object_ids == 0
1311
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1312
+ img = jnp.where(empty_mask[..., None], white_color, img)
1313
+
1300
1314
  if self.nowrap:
1301
1315
  # For out-of-bounds, use padding object color
1302
1316
  y_out = (y_coords < 0) | (y_coords >= self.size[1])