continual-foragax 0.32.1__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.32.1
3
+ Version: 0.33.1
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,6 +1,6 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=Jjo7XypfQf6ePoKYTV3xvolK4qpacuaifWLZB0ke5y8,54559
3
+ foragax/env.py,sha256=bDhNQTaqcoOwm9Csb1LHoduuNdE1j1RAhGnVV7cAEPI,55147
4
4
  foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
5
5
  foragax/registry.py,sha256=hfzQHNgX6uoOdbf4_21iH25abQVQZIjBWn7h5bdrSBg,17981
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.32.1.dist-info/METADATA,sha256=ZG39JPQKbUW7ag-vTZtcDfL8Wvt-nCfO-KOCOZMgOIo,4713
132
- continual_foragax-0.32.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.32.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.32.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.32.1.dist-info/RECORD,,
131
+ continual_foragax-0.33.1.dist-info/METADATA,sha256=n4TeSbm2f0oBPHWnRPozN-Y76qn3RlF9ey2QI8jtVU0,4713
132
+ continual_foragax-0.33.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.33.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.33.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.33.1.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -562,6 +562,23 @@ class ForagaxEnv(environment.Environment):
562
562
  biome_state=biome_state,
563
563
  )
564
564
 
565
+ # Compute reward at each grid position
566
+ fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
567
+
568
+ def compute_reward(obj_id, params):
569
+ return jax.lax.cond(
570
+ obj_id > 0,
571
+ lambda: jax.lax.switch(
572
+ obj_id, self.reward_fns, state.time, fixed_key, params
573
+ ),
574
+ lambda: 0.0,
575
+ )
576
+
577
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
578
+ object_state.object_id, object_state.state_params
579
+ )
580
+ info["rewards"] = reward_grid
581
+
565
582
  done = self.is_terminal(state, params)
566
583
  return (
567
584
  jax.lax.stop_gradient(self.get_obs(state, params)),