continual-foragax 0.42.1__tar.gz → 0.42.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/PKG-INFO +1 -1
  2. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/pyproject.toml +2 -2
  3. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/PKG-INFO +1 -1
  4. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/env.py +118 -163
  5. continual_foragax-0.42.2/src/foragax/rendering.py +171 -0
  6. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/tests/test_foragax.py +63 -0
  7. continual_foragax-0.42.1/src/foragax/rendering.py +0 -53
  8. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/README.md +0 -0
  9. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/setup.cfg +0 -0
  10. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/SOURCES.txt +0 -0
  11. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/dependency_links.txt +0 -0
  12. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/entry_points.txt +0 -0
  13. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/requires.txt +0 -0
  14. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/top_level.txt +0 -0
  15. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/__init__.py +0 -0
  16. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/colors.py +0 -0
  17. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt +0 -0
  18. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100928.txt +0 -0
  19. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100929.txt +0 -0
  20. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100930.txt +0 -0
  21. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100931.txt +0 -0
  22. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106714.txt +0 -0
  23. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106715.txt +0 -0
  24. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106716.txt +0 -0
  25. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106717.txt +0 -0
  26. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106718.txt +0 -0
  27. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106930.txt +0 -0
  28. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106931.txt +0 -0
  29. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106932.txt +0 -0
  30. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106933.txt +0 -0
  31. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106934.txt +0 -0
  32. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106935.txt +0 -0
  33. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106936.txt +0 -0
  34. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106937.txt +0 -0
  35. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106938.txt +0 -0
  36. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106939.txt +0 -0
  37. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106940.txt +0 -0
  38. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106941.txt +0 -0
  39. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106942.txt +0 -0
  40. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106943.txt +0 -0
  41. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106994.txt +0 -0
  42. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106995.txt +0 -0
  43. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106996.txt +0 -0
  44. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106997.txt +0 -0
  45. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106998.txt +0 -0
  46. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106999.txt +0 -0
  47. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107000.txt +0 -0
  48. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107001.txt +0 -0
  49. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107002.txt +0 -0
  50. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107003.txt +0 -0
  51. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107004.txt +0 -0
  52. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107005.txt +0 -0
  53. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107006.txt +0 -0
  54. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107007.txt +0 -0
  55. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107008.txt +0 -0
  56. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107009.txt +0 -0
  57. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107010.txt +0 -0
  58. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107011.txt +0 -0
  59. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107012.txt +0 -0
  60. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107013.txt +0 -0
  61. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107014.txt +0 -0
  62. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107015.txt +0 -0
  63. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107016.txt +0 -0
  64. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107017.txt +0 -0
  65. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107018.txt +0 -0
  66. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107019.txt +0 -0
  67. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107020.txt +0 -0
  68. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107021.txt +0 -0
  69. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107022.txt +0 -0
  70. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107023.txt +0 -0
  71. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107024.txt +0 -0
  72. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107025.txt +0 -0
  73. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107026.txt +0 -0
  74. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107027.txt +0 -0
  75. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107028.txt +0 -0
  76. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107029.txt +0 -0
  77. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107030.txt +0 -0
  78. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107031.txt +0 -0
  79. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107032.txt +0 -0
  80. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107033.txt +0 -0
  81. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107034.txt +0 -0
  82. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107035.txt +0 -0
  83. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107036.txt +0 -0
  84. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107037.txt +0 -0
  85. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107038.txt +0 -0
  86. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107039.txt +0 -0
  87. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107040.txt +0 -0
  88. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107041.txt +0 -0
  89. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107042.txt +0 -0
  90. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107043.txt +0 -0
  91. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107044.txt +0 -0
  92. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107045.txt +0 -0
  93. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107046.txt +0 -0
  94. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107047.txt +0 -0
  95. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107048.txt +0 -0
  96. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107049.txt +0 -0
  97. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107050.txt +0 -0
  98. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107051.txt +0 -0
  99. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107052.txt +0 -0
  100. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107053.txt +0 -0
  101. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107054.txt +0 -0
  102. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107055.txt +0 -0
  103. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107056.txt +0 -0
  104. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107057.txt +0 -0
  105. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107058.txt +0 -0
  106. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107059.txt +0 -0
  107. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107060.txt +0 -0
  108. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107061.txt +0 -0
  109. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107062.txt +0 -0
  110. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107063.txt +0 -0
  111. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107064.txt +0 -0
  112. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107065.txt +0 -0
  113. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107066.txt +0 -0
  114. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107067.txt +0 -0
  115. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107068.txt +0 -0
  116. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107069.txt +0 -0
  117. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107070.txt +0 -0
  118. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107071.txt +0 -0
  119. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115808.txt +0 -0
  120. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115812.txt +0 -0
  121. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID146811.txt +0 -0
  122. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156831.txt +0 -0
  123. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156835.txt +0 -0
  124. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156839.txt +0 -0
  125. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156843.txt +0 -0
  126. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156847.txt +0 -0
  127. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156851.txt +0 -0
  128. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156855.txt +0 -0
  129. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156859.txt +0 -0
  130. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156863.txt +0 -0
  131. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156867.txt +0 -0
  132. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156871.txt +0 -0
  133. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156875.txt +0 -0
  134. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156879.txt +0 -0
  135. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156883.txt +0 -0
  136. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt +0 -0
  137. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/elements.txt +0 -0
  138. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/metadata.txt +0 -0
  139. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/sources.txt +0 -0
  140. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/objects.py +0 -0
  141. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/registry.py +0 -0
  142. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/src/foragax/weather.py +0 -0
  143. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/tests/test_benchmark.py +0 -0
  144. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/tests/test_optimize.py +0 -0
  145. {continual_foragax-0.42.1 → continual_foragax-0.42.2}/tests/test_registry.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: continual-foragax
- Version: 0.42.1
+ Version: 0.42.2
  Summary: A continual reinforcement learning benchmark
  Author-email: Steven Tang <stang5@ualberta.ca>
  Requires-Python: >=3.8
@@ -1,6 +1,6 @@
  [project]
  name = "continual-foragax"
- version = "0.42.1"
+ version = "0.42.2"
  description = "A continual reinforcement learning benchmark"
  readme = "README.md"
  authors = [
@@ -34,7 +34,7 @@ build-backend = "setuptools.build_meta"
  [tool]
  [tool.commitizen]
  name = "cz_conventional_commits"
- version = "0.42.1"
+ version = "0.42.2"
  tag_format = "$version"
  version_files = ["pyproject.toml"]
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: continual-foragax
- Version: 0.42.1
+ Version: 0.42.2
  Summary: A continual reinforcement learning benchmark
  Author-email: Steven Tang <stang5@ualberta.ca>
  Requires-Python: >=3.8
@@ -22,7 +22,13 @@ from foragax.objects import (
  FourierObject,
  WeatherObject,
  )
- from foragax.rendering import apply_true_borders
+ from foragax.rendering import (
+ apply_grid_lines,
+ apply_reward_overlay,
+ apply_true_borders,
+ get_base_image,
+ reward_to_color,
+ )
  from foragax.weather import get_temperature
 
 
@@ -364,6 +370,19 @@ class ForagaxEnv(environment.Environment):
  jnp.array(0, dtype=ID_DTYPE)
  )
 
+ # Extract the actual state to move
+ obj_color = object_state.color[y, x]
+ obj_params = object_state.state_params[y, x]
+ obj_gen = object_state.generation[y, x]
+
+ # Clear visuals at old position
+ new_color = object_state.color.at[y, x].set(
+ jnp.zeros(3, dtype=COLOR_DTYPE)
+ )
+ new_params = object_state.state_params.at[y, x].set(
+ jnp.zeros_like(obj_params)
+ )
+
  # Find valid spawn locations in the same biome
  biome_id = object_state.biome_id[y, x]
  biome_mask = object_state.biome_id == biome_id
@@ -381,7 +400,7 @@ class ForagaxEnv(environment.Environment):
  )
  new_spawn_pos = valid_spawn_indices[random_idx]
 
- # Place timer at the new random position
+ # Place timer and move properties at the new random position
  new_respawn_timer = new_respawn_timer.at[
  new_spawn_pos[0], new_spawn_pos[1]
  ].set(timer_val)
@@ -389,10 +408,24 @@ class ForagaxEnv(environment.Environment):
  new_spawn_pos[0], new_spawn_pos[1]
  ].set(object_type)
 
+ # Move properties to new position
+ new_color = new_color.at[new_spawn_pos[0], new_spawn_pos[1]].set(
+ obj_color
+ )
+ new_params = new_params.at[new_spawn_pos[0], new_spawn_pos[1]].set(
+ obj_params
+ )
+ new_generation = object_state.generation.at[
+ new_spawn_pos[0], new_spawn_pos[1]
+ ].set(obj_gen)
+
  return object_state.replace(
  object_id=new_object_id,
  respawn_timer=new_respawn_timer,
  respawn_object_id=new_respawn_object_id,
+ color=new_color,
+ state_params=new_params,
+ generation=new_generation,
  )
 
  return jax.lax.cond(random_respawn, place_randomly, place_at_position)
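The property move above uses JAX's functional array updates: .at[y, x].set(...) returns a new array instead of mutating in place, so clearing the source cell and writing the randomly chosen destination cell each yield fresh copies of the color, parameter, and generation grids. A minimal sketch of the same move pattern on a made-up 3x3 color grid (shapes and values are illustrative, not the package's actual state):

import jax.numpy as jnp

# Hypothetical 3x3 grid of per-cell RGB colors (illustrative shapes only).
color = jnp.zeros((3, 3, 3), dtype=jnp.uint8)
color = color.at[0, 0].set(jnp.array([0, 128, 0], dtype=jnp.uint8))

# "Move" the cell contents from (0, 0) to (2, 1) without in-place mutation.
src = color[0, 0]                                     # read the value to move
color = color.at[0, 0].set(jnp.zeros(3, jnp.uint8))   # clear the old position
color = color.at[2, 1].set(src)                       # write the new position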
@@ -604,22 +637,22 @@ class ForagaxEnv(environment.Environment):
  # Compute reward at each grid position
  fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
 
- def compute_reward(obj_id, params):
- return jax.lax.cond(
- obj_id > jnp.array(0, dtype=ID_DTYPE),
- lambda: jax.lax.switch(
- obj_id.astype(jnp.int32),
- self.reward_fns,
- state.time,
- fixed_key,
- params.astype(jnp.float32),
- ),
- lambda: 0.0,
+ def compute_reward(obj_id, params, timer):
+ reward = jax.lax.switch(
+ obj_id.astype(jnp.int32),
+ self.reward_fns,
+ state.time,
+ fixed_key,
+ params.astype(jnp.float32),
  )
+ # Only show reward for objects that are fully present (no timer)
+ mask = (obj_id > 0) & (timer == 0)
+ return jnp.where(mask, reward, 0.0)
 
  reward_grid = jax.vmap(jax.vmap(compute_reward))(
  object_state.object_id.astype(ID_DTYPE),
  object_state.state_params.astype(PARAM_DTYPE),
+ object_state.respawn_timer.astype(TIMER_DTYPE),
  )
  return reward_grid
 
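The compute_reward refactor above replaces a jax.lax.cond guard with an unconditional jax.lax.switch followed by a jnp.where mask. Under the double jax.vmap a batched cond is lowered to evaluating both branches and selecting anyway, so the masked form is equivalent, and the extra timer argument lets rewards be suppressed for objects that are still waiting to respawn. A standalone sketch of the pattern, with hypothetical reward functions standing in for self.reward_fns:

import jax
import jax.numpy as jnp

# Hypothetical per-object reward functions, indexed by object id (id 0 = empty cell).
reward_fns = [lambda t: 0.0, lambda t: 1.0, lambda t: -1.0]

def compute_reward(obj_id, timer, t):
    reward = jax.lax.switch(obj_id, reward_fns, t)
    # Only count objects that are fully present (not waiting to respawn).
    return jnp.where((obj_id > 0) & (timer == 0), reward, 0.0)

obj_ids = jnp.array([[0, 1], [2, 1]])
timers = jnp.array([[0, 0], [3, 0]])
reward_grid = jax.vmap(jax.vmap(compute_reward, in_axes=(0, 0, None)), in_axes=(0, 0, None))(
    obj_ids, timers, 0.0
)
# reward_grid == [[0., 1.], [0., 1.]]  (the id-2 object is masked out by its timer)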
@@ -1465,68 +1498,34 @@ class ForagaxEnv(environment.Environment):
  return spaces.Box(0, 1, obs_shape, float)
 
  def _compute_reward_grid(
- self, state: EnvState, object_id=None, state_params=None
+ self, state: EnvState, object_id=None, state_params=None, respawn_timer=None
  ) -> jax.Array:
  """Compute rewards for given positions. If no grid provided, uses full world."""
  if object_id is None:
  object_id = state.object_state.object_id
  if state_params is None:
  state_params = state.object_state.state_params
+ if respawn_timer is None:
+ respawn_timer = state.object_state.respawn_timer
 
  fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
 
- def compute_reward(obj_id, params):
- return jax.lax.cond(
- obj_id > 0,
- lambda: jax.lax.switch(
- obj_id, self.reward_fns, state.time, fixed_key, params
- ),
- lambda: 0.0,
+ def compute_reward(obj_id, params, timer):
+ reward = jax.lax.switch(
+ obj_id.astype(jnp.int32),
+ self.reward_fns,
+ state.time,
+ fixed_key,
+ params.astype(jnp.float32),
  )
+ # Only show reward for objects that are fully present (no timer)
+ mask = (obj_id > 0) & (timer == 0)
+ return jnp.where(mask, reward, 0.0)
 
- reward_grid = jax.vmap(jax.vmap(compute_reward))(object_id, state_params)
- return reward_grid
-
- def _reward_to_color(self, reward: jax.Array) -> jax.Array:
- """Convert reward value to RGB color using diverging gradient.
-
- Args:
- reward: Reward value (typically -1 to +1)
-
- Returns:
- RGB color array with shape (..., 3) and dtype uint8
- """
- # Diverging gradient: +1 = green (0, 255, 0), 0 = white (255, 255, 255), -1 = magenta (255, 0, 255)
- # Clamp reward to [-1, 1] range for color mapping
- reward_clamped = jnp.clip(reward, -1.0, 1.0)
-
- # For positive rewards: interpolate from white to green
- # For negative rewards: interpolate from white to magenta
- # At reward = 0: white (255, 255, 255)
- # At reward = +1: green (0, 255, 0)
- # At reward = -1: magenta (255, 0, 255)
-
- red_component = jnp.where(
- reward_clamped >= 0,
- (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
- 255, # Stay at 255 for all negative rewards
- )
-
- green_component = jnp.where(
- reward_clamped >= 0,
- 255, # Stay at 255 for all positive rewards
- (1 + reward_clamped) * 255, # Fade from white to magenta: 255 -> 0
- )
-
- blue_component = jnp.where(
- reward_clamped >= 0,
- (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
- 255, # Stay at 255 for all negative rewards
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
+ object_id, state_params, respawn_timer
  )
-
- return jnp.stack(
- [red_component, green_component, blue_component], axis=-1
- ).astype(jnp.uint8)
+ return reward_grid
 
  @partial(jax.jit, static_argnames=("self", "render_mode"))
  def render(
@@ -1535,13 +1534,7 @@ class ForagaxEnv(environment.Environment):
  params: EnvParams,
  render_mode: str = "world",
  ):
- """Render the environment state.
-
- Args:
- state: Current environment state
- params: Environment parameters
- render_mode: One of "world", "world_true", "world_reward", "aperture", "aperture_true", "aperture_reward"
- """
+ """Render the environment state."""
  is_world_mode = render_mode in ("world", "world_true", "world_reward")
  is_aperture_mode = render_mode in (
  "aperture",
@@ -1552,27 +1545,12 @@ class ForagaxEnv(environment.Environment):
  is_reward_mode = render_mode in ("world_reward", "aperture_reward")
 
  if is_world_mode:
- # Create an RGB image from the object grid
- # Use stateful object colors if dynamic_biomes is enabled, else use default colors
- if self.dynamic_biomes:
- # Use per-instance colors from state
- img = state.object_state.color.copy()
- # Mask empty cells (object_id == 0) to white
- empty_mask = state.object_state.object_id == 0
- white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
- img = jnp.where(empty_mask[..., None], white_color, img)
- else:
- # Use default object colors
- img = jnp.zeros((self.size[1], self.size[0], 3))
- render_grid = state.object_state.object_id
-
- def update_image(i, img):
- color = self.object_colors[i]
- mask = render_grid == i
- img = jnp.where(mask[..., None], color, img)
- return img
-
- img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
+ img = get_base_image(
+ state.object_state.object_id,
+ state.object_state.color,
+ self.object_colors,
+ self.dynamic_biomes,
+ )
 
  # Define constants for all world modes
  alpha = 0.2
@@ -1582,26 +1560,18 @@ class ForagaxEnv(environment.Environment):
 
  if is_reward_mode:
  # Construct 3x intermediate image
- # Each cell is 3x3, with reward color in center
  reward_grid = self._compute_reward_grid(state)
- reward_colors = self._reward_to_color(reward_grid)
-
- # Each cell has its base color in 8 pixels and reward color in 1 (center)
- # Create a 3x3 pattern mask for center pixels
- cell_mask = jnp.array(
- [[False, False, False], [False, True, False], [False, False, False]]
- )
- grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
+ reward_colors = reward_to_color(reward_grid)
 
- # Repeat base colors and rewards to 3x3
+ # Repeat base colors to 3x scale
  base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
- reward_colors_x3 = jnp.repeat(
- jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
- )
 
- # Composite base and reward colors
- img = jnp.where(
- grid_reward_mask[..., None], reward_colors_x3, base_img_x3
+ # Composite base and reward colors using helper
+ img = apply_reward_overlay(
+ base_img_x3,
+ reward_colors,
+ reward_grid,
+ self.size,
  )
 
  # Tint the aperture region at 3x scale
@@ -1647,83 +1617,72 @@ class ForagaxEnv(environment.Environment):
  img, state.object_state.object_id, self.size, len(self.object_ids)
  )
 
- # Add grid lines using masking instead of slice-setting
- row_grid = (jnp.arange(self.size[1] * 24) % 24) == 0
- col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
- # skip first rows/cols as they are borders or managed by caller
- row_grid = row_grid.at[0].set(False)
- col_grid = col_grid.at[0].set(False)
- grid_mask = row_grid[:, None] | col_grid[None, :]
- img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
+ # Add grid lines
+ img = apply_grid_lines(img, self.size, self.grid_color_jax)
 
  elif is_aperture_mode:
  obs_grid = state.object_state.object_id
  aperture = self._get_aperture(obs_grid, state.pos)
 
- if self.dynamic_biomes:
- # Use per-instance colors from state - extract aperture view
- y_coords, x_coords, y_coords_adj, x_coords_adj = (
- self._compute_aperture_coordinates(state.pos)
- )
- img = state.object_state.color[y_coords_adj, x_coords_adj]
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
+ self._compute_aperture_coordinates(state.pos)
+ )
+ color_state = state.object_state.color[y_coords_adj, x_coords_adj]
 
- # Mask empty cells (object_id == 0) to white
- aperture_object_ids = state.object_state.object_id[
- y_coords_adj, x_coords_adj
- ]
- empty_mask = aperture_object_ids == 0
- white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
- img = jnp.where(empty_mask[..., None], white_color, img)
-
- if self.nowrap:
- # For out-of-bounds, use padding object color
- y_out = (y_coords < 0) | (y_coords >= self.size[1])
- x_out = (x_coords < 0) | (x_coords >= self.size[0])
- out_of_bounds = y_out | x_out
- padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
- img = jnp.where(out_of_bounds[..., None], padding_color, img)
- else:
- # Use default object colors
- aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
- img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
+ img = get_base_image(
+ aperture,
+ color_state,
+ self.object_colors,
+ self.dynamic_biomes,
+ )
+
+ if self.dynamic_biomes and self.nowrap:
+ # For out-of-bounds, use padding object color
+ y_out = (y_coords < 0) | (y_coords >= self.size[1])
+ x_out = (x_coords < 0) | (x_coords >= self.size[0])
+ out_of_bounds = y_out | x_out
+ padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
+ img = jnp.where(out_of_bounds[..., None], padding_color, img)
 
  if is_reward_mode:
  # Scale image by 3 to create space for reward visualization
  img = img.astype(jnp.uint8)
  img = jax.image.resize(
  img,
- (self.aperture_size[0] * 3, self.aperture_size[1] * 3, 3),
+ (
+ self.aperture_size[0] * 3,
+ self.aperture_size[1] * 3,
+ 3,
+ ),
  jax.image.ResizeMethod.NEAREST,
  )
 
  # Compute rewards for aperture region
- y_coords, x_coords, y_coords_adj, x_coords_adj = (
- self._compute_aperture_coordinates(state.pos)
- )
-
- # Get reward grid only for aperture region
- aperture_object_ids = state.object_state.object_id[
- y_coords_adj, x_coords_adj
- ]
  aperture_params = state.object_state.state_params[
  y_coords_adj, x_coords_adj
  ]
+ aperture_timer = self._get_aperture(
+ state.object_state.respawn_timer, state.pos
+ )
  aperture_rewards = self._compute_reward_grid(
- state, aperture_object_ids, aperture_params
+ state, aperture, aperture_params, aperture_timer
  )
 
  # Convert rewards to colors
- reward_colors = self._reward_to_color(aperture_rewards)
+ reward_colors = reward_to_color(aperture_rewards)
 
- # Place reward colors in the middle cells (index 1 in each 3x3 block)
- i_indices = jnp.arange(self.aperture_size[0])[:, None] * 3 + 1
- j_indices = jnp.arange(self.aperture_size[1])[None, :] * 3 + 1
- img = img.at[i_indices, j_indices].set(reward_colors)
+ # Apply reward overlay using helper
+ img = apply_reward_overlay(
+ img,
+ reward_colors,
+ aperture_rewards,
+ self.aperture_size,
+ )
 
  # Draw agent in the center (all 9 cells of the 3x3 block)
  center_y, center_x = (
- self.aperture_size[1] // 2,
  self.aperture_size[0] // 2,
+ self.aperture_size[1] // 2,
  )
  agent_offsets = jnp.array(
  [[dy, dx] for dy in range(3) for dx in range(3)]
@@ -1757,17 +1716,13 @@ class ForagaxEnv(environment.Environment):
  )
 
  if is_true_mode:
- # Apply true object borders by overlaying true colors on border pixels
+ # Apply true object borders
  img = apply_true_borders(
  img, aperture, self.aperture_size, len(self.object_ids)
  )
 
- # Add grid lines for aperture mode
- grid_color = jnp.zeros(3, dtype=jnp.uint8)
- row_indices = jnp.arange(1, self.aperture_size[0]) * 24
- col_indices = jnp.arange(1, self.aperture_size[1]) * 24
- img = img.at[row_indices, :].set(grid_color)
- img = img.at[:, col_indices].set(grid_color)
+ # Add grid lines
+ img = apply_grid_lines(img, self.aperture_size, self.grid_color_jax)
 
  else:
  raise ValueError(f"Unknown render_mode: {render_mode}")
@@ -0,0 +1,171 @@
+ """Rendering utilities for Foragax environments."""
+
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+
+ from foragax.colors import hsv_to_rgb_255
+
+
+ def apply_true_borders(
+ base_img: jax.Array,
+ true_grid: jax.Array,
+ grid_size: Tuple[int, int],
+ num_objects: int,
+ ) -> jax.Array:
+ """Apply true object borders by overlaying HSV border colors on border pixels.
+
+ Args:
+ base_img: Base image with object colors
+ true_grid: Grid of object IDs for determining border colors
+ grid_size: (height, width) of the grid
+ num_objects: Number of object types
+
+ Returns:
+ Image with HSV borders overlaid on border pixels
+ """
+ # Create HSV border colors for each object type
+ hues = jnp.linspace(0, 1, num_objects, endpoint=False)
+
+ # Convert HSV to RGB for border colors
+ border_colors = hsv_to_rgb_255(hues[true_grid])
+
+ # Resize border colors to match rendered image size
+ border_img = jax.image.resize(
+ border_colors,
+ (grid_size[0] * 24, grid_size[1] * 24, 3),
+ jax.image.ResizeMethod.NEAREST,
+ )
+
+ # Create border mask (2-pixel thick borders) using vectorized modulo operations
+ img_height, img_width = grid_size[0] * 24, grid_size[1] * 24
+ y_idx = jnp.arange(img_height) % 24
+ x_idx = jnp.arange(img_width) % 24
+
+ # Border pixels are those with offset 0, 1, 22, or 23 within each 24x24 cell
+ is_border_row = (y_idx < 2) | (y_idx >= 22)
+ is_border_col = (x_idx < 2) | (x_idx >= 22)
+ border_mask = is_border_row[:, None] | is_border_col[None, :]
+
+ # Apply border mask: use HSV border colors for border pixels, base colors elsewhere
+ result_img = jnp.where(border_mask[..., None], border_img, base_img)
+ return result_img
+
+
+ def reward_to_color(reward: jax.Array) -> jax.Array:
+ """Convert reward value to RGB color using diverging gradient.
+
+ Args:
+ reward: Reward value (typically -1 to +1)
+
+ Returns:
+ RGB color array with shape (..., 3) and dtype uint8
+ """
+ # Diverging gradient: +1 = green (0, 255, 0), 0 = white (255, 255, 255), -1 = magenta (255, 0, 255)
+ # Clamp reward to [-1, 1] range for color mapping
+ reward_clamped = jnp.clip(reward, -1.0, 1.0)
+
+ # For positive rewards: interpolate from white to green
+ # For negative rewards: interpolate from white to magenta
+ # At reward = 0: white (255, 255, 255)
+ # At reward = +1: green (0, 255, 0)
+ # At reward = -1: magenta (255, 0, 255)
+
+ red_component = jnp.where(
+ reward_clamped >= 0,
+ (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
+ 255, # Stay at 255 for all negative rewards
+ )
+
+ green_component = jnp.where(
+ reward_clamped >= 0,
+ 255, # Stay at 255 for all positive rewards
+ (1 + reward_clamped) * 255, # Fade from white to magenta: 255 -> 0
+ )
+
+ blue_component = jnp.where(
+ reward_clamped >= 0,
+ (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
+ 255, # Stay at 255 for all negative rewards
+ )
+
+ return jnp.stack([red_component, green_component, blue_component], axis=-1).astype(
+ jnp.uint8
+ )
+
+
+ def get_base_image(
+ object_id: jax.Array,
+ color_state: jax.Array,
+ object_colors: jax.Array,
+ dynamic_biomes: bool,
+ ) -> jax.Array:
+ """Construct base RGB image from object IDs or colors."""
+ if dynamic_biomes:
+ # Use per-instance colors from state
+ img = color_state.copy()
+ # Mask empty cells (object_id == 0) to white
+ empty_mask = object_id == 0
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
+ img = jnp.where(empty_mask[..., None], white_color, img)
+ else:
+ # Map object IDs to colors
+ img = object_colors[object_id]
+
+ return img.astype(jnp.uint8)
+
+
+ def apply_grid_lines(
+ img: jax.Array,
+ grid_size: Tuple[int, int],
+ grid_color: jax.Array,
+ cell_size: int = 24,
+ ) -> jax.Array:
+ """Apply grid lines to the image."""
+ row_grid = (jnp.arange(grid_size[0] * cell_size) % cell_size) == 0
+ col_grid = (jnp.arange(grid_size[1] * cell_size) % cell_size) == 0
+ # skip first rows/cols as they are borders or managed by caller
+ row_grid = row_grid.at[0].set(False)
+ col_grid = col_grid.at[0].set(False)
+ grid_mask = row_grid[:, None] | col_grid[None, :]
+ return jnp.where(grid_mask[..., None], grid_color, img)
+
+
+ def apply_reward_overlay(
+ base_img: jax.Array,
+ reward_colors: jax.Array,
+ reward_grid: jax.Array,
+ grid_size: Tuple[int, int],
+ ) -> jax.Array:
+ """Apply reward visualization overlay (center dots) to the image.
+
+ Only applies dots where the reward is non-zero (abs > 1e-5).
+
+ Args:
+ base_img: Base image at 3x scale (each cell is 3x3)
+ reward_colors: Array of RGB colors for rewards
+ reward_grid: Grid of reward values
+ grid_size: (height, width) of the grid
+
+ Returns:
+ Image with reward dots overlaid
+ """
+ # Create a 3x3 pattern mask for center pixels
+ cell_mask = jnp.array(
+ [[False, False, False], [False, True, False], [False, False, False]]
+ )
+ grid_reward_mask = jnp.tile(cell_mask, grid_size)
+
+ # Only show reward where reward is meaningfully non-zero
+ reward_nonzero = jnp.abs(reward_grid) > 1e-5
+ # Expand to 3x scale
+ reward_nonzero_x3 = jnp.repeat(jnp.repeat(reward_nonzero, 3, axis=0), 3, axis=1)
+
+ # Final mask: center pixel of a cell AND cell has a non-zero reward
+ composite_mask = grid_reward_mask & reward_nonzero_x3
+
+ # Repeat reward colors to 3x to match image scale
+ reward_colors_x3 = jnp.repeat(jnp.repeat(reward_colors, 3, axis=0), 3, axis=1)
+
+ return jnp.where(composite_mask[..., None], reward_colors_x3, base_img)
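Taken together, a world-mode reward render composes the new helpers roughly as follows. This is a minimal sketch with made-up grid contents; the real call sites (aperture handling, agent drawing, true-border overlays) live in ForagaxEnv.render:

import jax.numpy as jnp

from foragax.rendering import (
    apply_grid_lines,
    apply_reward_overlay,
    get_base_image,
    reward_to_color,
)

# Hypothetical 4x4 world with one object type besides "empty"; values are illustrative.
object_id = jnp.zeros((4, 4), dtype=jnp.int32).at[1, 2].set(1)
object_colors = jnp.array([[255, 255, 255], [0, 128, 0]], dtype=jnp.uint8)
reward_grid = jnp.zeros((4, 4)).at[1, 2].set(1.0)

# Base colors per cell (color_state is only read when dynamic_biomes=True).
img = get_base_image(object_id, None, object_colors, dynamic_biomes=False)

# Upscale 3x so each cell becomes a 3x3 block, then drop a reward-colored dot in
# the center pixel of every cell whose reward is non-zero.
img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
img_x3 = apply_reward_overlay(img_x3, reward_to_color(reward_grid), reward_grid, (4, 4))

# Grid lines; cell_size=3 because this sketch stays at 3 px per cell, whereas
# render() applies them on a 24 px-per-cell image.
img_x3 = apply_grid_lines(img_x3, (4, 4), jnp.zeros(3, dtype=jnp.uint8), cell_size=3)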