continual-foragax 0.32.0__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
@@ -1,17 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.32.0
3
+ Version: 0.33.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: gymnax
9
9
  Requires-Dist: six; python_version < "3.10"
10
- Provides-Extra: dev
11
- Requires-Dist: pre-commit; extra == "dev"
12
- Requires-Dist: pytest; extra == "dev"
13
- Requires-Dist: pytest-benchmark; extra == "dev"
14
- Requires-Dist: ruff; extra == "dev"
15
10
 
16
11
  # foragax
17
12
 
@@ -1,6 +1,6 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=OgZegbHnmVCchSEBHfwm4Kgc4rrwTEnx6r4WdAvK_P4,53800
3
+ foragax/env.py,sha256=kf0B0D5CMUnOMiZ8diBlmlt_vf2Wh-M4gSNoeo4jfHY,55147
4
4
  foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
5
5
  foragax/registry.py,sha256=hfzQHNgX6uoOdbf4_21iH25abQVQZIjBWn7h5bdrSBg,17981
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.32.0.dist-info/METADATA,sha256=2FLbgAsQJg-W3DbOvDap6PRb5_ku7g8mfuXFuUZ2Ybs,4897
132
- continual_foragax-0.32.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.32.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.32.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.32.0.dist-info/RECORD,,
131
+ continual_foragax-0.33.0.dist-info/METADATA,sha256=vEOZLNVNPhccIZDIrN-puYjVPQKWHtThbhoVdUjhF4A,4713
132
+ continual_foragax-0.33.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.33.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.33.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.33.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -506,17 +506,6 @@ class ForagaxEnv(environment.Environment):
506
506
  lambda: object_state,
507
507
  )
508
508
 
509
- # Clear color grid when object is collected
510
- object_state = jax.lax.cond(
511
- should_collect_now,
512
- lambda: object_state.replace(
513
- color=object_state.color.at[pos[1], pos[0]].set(
514
- jnp.full((3,), 255, dtype=jnp.uint8)
515
- )
516
- ),
517
- lambda: object_state,
518
- )
519
-
520
509
  # 3.5. HANDLE OBJECT EXPIRY
521
510
  # Only process expiry if there are objects that can expire
522
511
  key, object_state = self.expire_objects(key, state, object_state)
@@ -564,6 +553,23 @@ class ForagaxEnv(environment.Environment):
564
553
  info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
565
554
  info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
566
555
 
556
+ # Compute reward at each grid position
557
+ fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
558
+
559
+ def compute_reward(obj_id, params):
560
+ return jax.lax.cond(
561
+ obj_id > 0,
562
+ lambda: jax.lax.switch(
563
+ obj_id, self.reward_fns, state.time, fixed_key, params
564
+ ),
565
+ lambda: 0.0,
566
+ )
567
+
568
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
569
+ object_state.object_id, object_state.state_params
570
+ )
571
+ info["rewards"] = reward_grid
572
+
567
573
  # 4. UPDATE STATE
568
574
  state = EnvState(
569
575
  pos=pos,
@@ -648,12 +654,6 @@ class ForagaxEnv(environment.Environment):
648
654
  rand_key,
649
655
  )
650
656
 
651
- # Clear color grid when object expires
652
- empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
653
- new_obj_state = new_obj_state.replace(
654
- color=new_obj_state.color.at[y, x].set(empty_color)
655
- )
656
-
657
657
  return new_obj_state
658
658
 
659
659
  def no_op():
@@ -1114,8 +1114,16 @@ class ForagaxEnv(environment.Environment):
1114
1114
  y_out = (y_coords < 0) | (y_coords >= self.size[1])
1115
1115
  x_out = (x_coords < 0) | (x_coords >= self.size[0])
1116
1116
  out_of_bounds = y_out | x_out
1117
- padding_index = self.object_ids[-1]
1118
- aperture = jnp.where(out_of_bounds, padding_index, values)
1117
+
1118
+ # Handle both object_id grids (2D) and color grids (3D)
1119
+ if len(values.shape) == 3:
1120
+ # Color grid: use PADDING color (0, 0, 0)
1121
+ padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
1122
+ aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
1123
+ else:
1124
+ # Object ID grid: use PADDING index
1125
+ padding_index = self.object_ids[-1]
1126
+ aperture = jnp.where(out_of_bounds, padding_index, values)
1119
1127
  else:
1120
1128
  aperture = values
1121
1129
 
@@ -1124,12 +1132,14 @@ class ForagaxEnv(environment.Environment):
1124
1132
  def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
1125
1133
  """Get observation based on observation_type and full_world."""
1126
1134
  obs_grid = state.object_state.object_id
1135
+ color_grid = state.object_state.color
1127
1136
 
1128
1137
  if self.full_world:
1129
1138
  return self._get_world_obs(obs_grid, state)
1130
1139
  else:
1131
1140
  grid = self._get_aperture(obs_grid, state.pos)
1132
- return self._get_aperture_obs(grid, state)
1141
+ color_grid = self._get_aperture(color_grid, state.pos)
1142
+ return self._get_aperture_obs(grid, color_grid, state)
1133
1143
 
1134
1144
  def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
1135
1145
  """Get world observation."""
@@ -1146,12 +1156,14 @@ class ForagaxEnv(environment.Environment):
1146
1156
  obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
1147
1157
  return obs
1148
1158
  elif self.observation_type == "rgb":
1149
- obs = jax.nn.one_hot(obs_grid, num_obj_types)
1150
- # Agent position
1151
- obs = obs.at[state.pos[1], state.pos[0], :].set(0)
1152
- obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
1153
- colors = self.object_colors / 255.0
1154
- obs = jnp.tensordot(obs, colors, axes=1)
1159
+ # Use state colors directly (supports dynamic biomes)
1160
+ colors = state.object_state.color / 255.0
1161
+
1162
+ # Mask empty cells (object_id == 0) to white
1163
+ empty_mask = obs_grid == 0
1164
+ white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
1165
+ obs = jnp.where(empty_mask[..., None], white_color, colors)
1166
+
1155
1167
  return obs
1156
1168
  elif self.observation_type == "color":
1157
1169
  # Handle case with no objects (only EMPTY)
@@ -1168,17 +1180,24 @@ class ForagaxEnv(environment.Environment):
1168
1180
  else:
1169
1181
  raise ValueError(f"Unknown observation_type: {self.observation_type}")
1170
1182
 
1171
- def _get_aperture_obs(self, aperture: jax.Array, state: EnvState) -> jax.Array:
1183
+ def _get_aperture_obs(
1184
+ self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
1185
+ ) -> jax.Array:
1172
1186
  """Get aperture observation."""
1173
1187
  if self.observation_type == "object":
1174
1188
  num_obj_types = len(self.object_ids)
1175
1189
  obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
1176
1190
  return obs
1177
1191
  elif self.observation_type == "rgb":
1178
- num_obj_types = len(self.object_ids)
1179
- aperture_one_hot = jax.nn.one_hot(aperture, num_obj_types)
1180
- colors = self.object_colors / 255.0
1181
- obs = jnp.tensordot(aperture_one_hot, colors, axes=1)
1192
+ # Use the color aperture that was passed in
1193
+ aperture_colors = color_aperture / 255.0
1194
+
1195
+ # Mask empty cells (object_id == 0) to white
1196
+ empty_mask = aperture == 0
1197
+ white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
1198
+
1199
+ obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
1200
+
1182
1201
  return obs
1183
1202
  elif self.observation_type == "color":
1184
1203
  # Handle case with no objects (only EMPTY)
@@ -1229,6 +1248,10 @@ class ForagaxEnv(environment.Environment):
1229
1248
  if self.dynamic_biomes:
1230
1249
  # Use per-instance colors from state
1231
1250
  img = state.object_state.color.copy()
1251
+ # Mask empty cells (object_id == 0) to white
1252
+ empty_mask = state.object_state.object_id == 0
1253
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1254
+ img = jnp.where(empty_mask[..., None], white_color, img)
1232
1255
  else:
1233
1256
  # Use default object colors
1234
1257
  img = jnp.zeros((self.size[1], self.size[0], 3))
@@ -1297,6 +1320,14 @@ class ForagaxEnv(environment.Environment):
1297
1320
  )
1298
1321
  img = state.object_state.color[y_coords_adj, x_coords_adj]
1299
1322
 
1323
+ # Mask empty cells (object_id == 0) to white
1324
+ aperture_object_ids = state.object_state.object_id[
1325
+ y_coords_adj, x_coords_adj
1326
+ ]
1327
+ empty_mask = aperture_object_ids == 0
1328
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1329
+ img = jnp.where(empty_mask[..., None], white_color, img)
1330
+
1300
1331
  if self.nowrap:
1301
1332
  # For out-of-bounds, use padding object color
1302
1333
  y_out = (y_coords < 0) | (y_coords >= self.size[1])