continual-foragax 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.10.0
3
+ Version: 0.10.1
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,6 +1,6 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=OtpcyqzBOQLdTvvRegD3SYm4mi4Ga2WE5eJ7OQmQOaw,18294
3
+ foragax/env.py,sha256=X9oc60xNL4uTFnbt_BynN_c3XmVa9MYbSclW-g4qQoc,18628
4
4
  foragax/objects.py,sha256=CyBxrykTxpHCI_2hE9jE8mG4TU8R7VxzKdQ5mtxkEqU,6004
5
5
  foragax/registry.py,sha256=7_RDXvm_3RNO7culBLGkE0jH8Wk_q6jbMv72dZx4JO8,2722
6
6
  foragax/rendering.py,sha256=KAoQpdndy5JDQlyG0c5QDHuH-_Tfy5RuVlDtndnHVjc,2765
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.10.0.dist-info/METADATA,sha256=Oo3pJnjoU7VLruPXrhNT-WfzF7cGvkGN8crW-BclMsk,4897
132
- continual_foragax-0.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.10.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.10.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.10.0.dist-info/RECORD,,
131
+ continual_foragax-0.10.1.dist-info/METADATA,sha256=tgSSa5FtfA6qRBCZ_L3r3XnsfDpjhKMTRFJ9TL_OfCo,4897
132
+ continual_foragax-0.10.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.10.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.10.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.10.1.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -492,7 +492,14 @@ class ForagaxWorldEnv(ForagaxEnv):
492
492
  num_obj_types = len(self.object_ids)
493
493
  # Decode grid for observation
494
494
  obs_grid = jnp.maximum(0, state.object_grid)
495
- obs = jax.nn.one_hot(obs_grid, num_obj_types)
495
+ obs = jnp.zeros((self.size[1], self.size[0], num_obj_types), dtype=jnp.float32)
496
+
497
+ num_object_channels = num_obj_types - 1
498
+ # create masks for all objects at once
499
+ object_ids = jnp.arange(1, num_obj_types)
500
+ object_masks = obs_grid[..., None] == object_ids[None, None, :]
501
+ obs = obs.at[:, :, :num_object_channels].set(object_masks.astype(float))
502
+
496
503
  obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
497
504
  return obs
498
505