continual-foragax 0.42.0__tar.gz → 0.42.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/PKG-INFO +1 -1
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/pyproject.toml +6 -2
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/PKG-INFO +1 -1
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/env.py +118 -163
- continual_foragax-0.42.2/src/foragax/rendering.py +171 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/tests/test_foragax.py +63 -0
- continual_foragax-0.42.0/src/foragax/rendering.py +0 -53
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/README.md +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/setup.cfg +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/SOURCES.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/dependency_links.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/entry_points.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/requires.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/continual_foragax.egg-info/top_level.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/__init__.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/colors.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100928.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100929.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100930.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100931.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106714.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106715.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106716.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106717.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106718.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106930.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106931.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106932.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106933.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106934.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106935.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106936.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106937.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106938.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106939.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106940.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106941.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106942.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106943.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106994.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106995.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106996.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106997.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106998.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106999.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107000.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107001.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107002.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107003.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107004.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107005.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107006.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107007.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107008.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107009.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107010.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107011.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107012.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107013.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107014.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107015.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107016.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107017.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107018.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107019.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107020.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107021.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107022.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107023.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107024.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107025.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107026.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107027.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107028.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107029.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107030.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107031.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107032.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107033.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107034.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107035.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107036.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107037.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107038.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107039.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107040.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107041.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107042.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107043.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107044.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107045.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107046.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107047.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107048.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107049.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107050.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107051.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107052.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107053.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107054.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107055.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107056.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107057.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107058.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107059.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107060.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107061.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107062.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107063.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107064.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107065.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107066.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107067.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107068.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107069.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107070.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107071.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115808.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115812.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID146811.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156831.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156835.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156839.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156843.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156847.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156851.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156855.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156859.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156863.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156867.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156871.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156875.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156879.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156883.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/elements.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/metadata.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/data/ECA_non-blended_custom/sources.txt +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/objects.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/registry.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/src/foragax/weather.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/tests/test_benchmark.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/tests/test_optimize.py +0 -0
- {continual_foragax-0.42.0 → continual_foragax-0.42.2}/tests/test_registry.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "continual-foragax"
|
|
3
|
-
version = "0.42.0"
|
|
3
|
+
version = "0.42.2"
|
|
4
4
|
description = "A continual reinforcement learning benchmark"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [
|
|
@@ -12,6 +12,10 @@ dependencies = [
|
|
|
12
12
|
"six; python_version < '3.10'",
|
|
13
13
|
]
|
|
14
14
|
|
|
15
|
+
[tool.uv]
|
|
16
|
+
index-strategy = "unsafe-best-match"
|
|
17
|
+
find-links = ["https://storage.googleapis.com/jax-releases/jax_releases.html"]
|
|
18
|
+
|
|
15
19
|
[dependency-groups]
|
|
16
20
|
dev = [
|
|
17
21
|
"pre-commit",
|
|
@@ -30,7 +34,7 @@ build-backend = "setuptools.build_meta"
|
|
|
30
34
|
[tool]
|
|
31
35
|
[tool.commitizen]
|
|
32
36
|
name = "cz_conventional_commits"
|
|
33
|
-
version = "0.42.0"
|
|
37
|
+
version = "0.42.2"
|
|
34
38
|
tag_format = "$version"
|
|
35
39
|
version_files = ["pyproject.toml"]
|
|
36
40
|
|
|
@@ -22,7 +22,13 @@ from foragax.objects import (
|
|
|
22
22
|
FourierObject,
|
|
23
23
|
WeatherObject,
|
|
24
24
|
)
|
|
25
|
-
from foragax.rendering import apply_true_borders
|
|
25
|
+
from foragax.rendering import (
|
|
26
|
+
apply_grid_lines,
|
|
27
|
+
apply_reward_overlay,
|
|
28
|
+
apply_true_borders,
|
|
29
|
+
get_base_image,
|
|
30
|
+
reward_to_color,
|
|
31
|
+
)
|
|
26
32
|
from foragax.weather import get_temperature
|
|
27
33
|
|
|
28
34
|
|
|
@@ -364,6 +370,19 @@ class ForagaxEnv(environment.Environment):
|
|
|
364
370
|
jnp.array(0, dtype=ID_DTYPE)
|
|
365
371
|
)
|
|
366
372
|
|
|
373
|
+
# Extract the actual state to move
|
|
374
|
+
obj_color = object_state.color[y, x]
|
|
375
|
+
obj_params = object_state.state_params[y, x]
|
|
376
|
+
obj_gen = object_state.generation[y, x]
|
|
377
|
+
|
|
378
|
+
# Clear visuals at old position
|
|
379
|
+
new_color = object_state.color.at[y, x].set(
|
|
380
|
+
jnp.zeros(3, dtype=COLOR_DTYPE)
|
|
381
|
+
)
|
|
382
|
+
new_params = object_state.state_params.at[y, x].set(
|
|
383
|
+
jnp.zeros_like(obj_params)
|
|
384
|
+
)
|
|
385
|
+
|
|
367
386
|
# Find valid spawn locations in the same biome
|
|
368
387
|
biome_id = object_state.biome_id[y, x]
|
|
369
388
|
biome_mask = object_state.biome_id == biome_id
|
|
@@ -381,7 +400,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
381
400
|
)
|
|
382
401
|
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
383
402
|
|
|
384
|
-
# Place timer at the new random position
|
|
403
|
+
# Place timer and move properties at the new random position
|
|
385
404
|
new_respawn_timer = new_respawn_timer.at[
|
|
386
405
|
new_spawn_pos[0], new_spawn_pos[1]
|
|
387
406
|
].set(timer_val)
|
|
@@ -389,10 +408,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
389
408
|
new_spawn_pos[0], new_spawn_pos[1]
|
|
390
409
|
].set(object_type)
|
|
391
410
|
|
|
411
|
+
# Move properties to new position
|
|
412
|
+
new_color = new_color.at[new_spawn_pos[0], new_spawn_pos[1]].set(
|
|
413
|
+
obj_color
|
|
414
|
+
)
|
|
415
|
+
new_params = new_params.at[new_spawn_pos[0], new_spawn_pos[1]].set(
|
|
416
|
+
obj_params
|
|
417
|
+
)
|
|
418
|
+
new_generation = object_state.generation.at[
|
|
419
|
+
new_spawn_pos[0], new_spawn_pos[1]
|
|
420
|
+
].set(obj_gen)
|
|
421
|
+
|
|
392
422
|
return object_state.replace(
|
|
393
423
|
object_id=new_object_id,
|
|
394
424
|
respawn_timer=new_respawn_timer,
|
|
395
425
|
respawn_object_id=new_respawn_object_id,
|
|
426
|
+
color=new_color,
|
|
427
|
+
state_params=new_params,
|
|
428
|
+
generation=new_generation,
|
|
396
429
|
)
|
|
397
430
|
|
|
398
431
|
return jax.lax.cond(random_respawn, place_randomly, place_at_position)
|
|
@@ -604,22 +637,22 @@ class ForagaxEnv(environment.Environment):
|
|
|
604
637
|
# Compute reward at each grid position
|
|
605
638
|
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
606
639
|
|
|
607
|
-
def compute_reward(obj_id, params):
|
|
608
|
-
|
|
609
|
-
obj_id
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
fixed_key,
|
|
615
|
-
params.astype(jnp.float32),
|
|
616
|
-
),
|
|
617
|
-
lambda: 0.0,
|
|
640
|
+
def compute_reward(obj_id, params, timer):
|
|
641
|
+
reward = jax.lax.switch(
|
|
642
|
+
obj_id.astype(jnp.int32),
|
|
643
|
+
self.reward_fns,
|
|
644
|
+
state.time,
|
|
645
|
+
fixed_key,
|
|
646
|
+
params.astype(jnp.float32),
|
|
618
647
|
)
|
|
648
|
+
# Only show reward for objects that are fully present (no timer)
|
|
649
|
+
mask = (obj_id > 0) & (timer == 0)
|
|
650
|
+
return jnp.where(mask, reward, 0.0)
|
|
619
651
|
|
|
620
652
|
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
621
653
|
object_state.object_id.astype(ID_DTYPE),
|
|
622
654
|
object_state.state_params.astype(PARAM_DTYPE),
|
|
655
|
+
object_state.respawn_timer.astype(TIMER_DTYPE),
|
|
623
656
|
)
|
|
624
657
|
return reward_grid
|
|
625
658
|
|
|
@@ -1465,68 +1498,34 @@ class ForagaxEnv(environment.Environment):
|
|
|
1465
1498
|
return spaces.Box(0, 1, obs_shape, float)
|
|
1466
1499
|
|
|
1467
1500
|
def _compute_reward_grid(
|
|
1468
|
-
self, state: EnvState, object_id=None, state_params=None
|
|
1501
|
+
self, state: EnvState, object_id=None, state_params=None, respawn_timer=None
|
|
1469
1502
|
) -> jax.Array:
|
|
1470
1503
|
"""Compute rewards for given positions. If no grid provided, uses full world."""
|
|
1471
1504
|
if object_id is None:
|
|
1472
1505
|
object_id = state.object_state.object_id
|
|
1473
1506
|
if state_params is None:
|
|
1474
1507
|
state_params = state.object_state.state_params
|
|
1508
|
+
if respawn_timer is None:
|
|
1509
|
+
respawn_timer = state.object_state.respawn_timer
|
|
1475
1510
|
|
|
1476
1511
|
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
1477
1512
|
|
|
1478
|
-
def compute_reward(obj_id, params):
|
|
1479
|
-
|
|
1480
|
-
obj_id
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1513
|
+
def compute_reward(obj_id, params, timer):
|
|
1514
|
+
reward = jax.lax.switch(
|
|
1515
|
+
obj_id.astype(jnp.int32),
|
|
1516
|
+
self.reward_fns,
|
|
1517
|
+
state.time,
|
|
1518
|
+
fixed_key,
|
|
1519
|
+
params.astype(jnp.float32),
|
|
1485
1520
|
)
|
|
1521
|
+
# Only show reward for objects that are fully present (no timer)
|
|
1522
|
+
mask = (obj_id > 0) & (timer == 0)
|
|
1523
|
+
return jnp.where(mask, reward, 0.0)
|
|
1486
1524
|
|
|
1487
|
-
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
def _reward_to_color(self, reward: jax.Array) -> jax.Array:
|
|
1491
|
-
"""Convert reward value to RGB color using diverging gradient.
|
|
1492
|
-
|
|
1493
|
-
Args:
|
|
1494
|
-
reward: Reward value (typically -1 to +1)
|
|
1495
|
-
|
|
1496
|
-
Returns:
|
|
1497
|
-
RGB color array with shape (..., 3) and dtype uint8
|
|
1498
|
-
"""
|
|
1499
|
-
# Diverging gradient: +1 = green (0, 255, 0), 0 = white (255, 255, 255), -1 = magenta (255, 0, 255)
|
|
1500
|
-
# Clamp reward to [-1, 1] range for color mapping
|
|
1501
|
-
reward_clamped = jnp.clip(reward, -1.0, 1.0)
|
|
1502
|
-
|
|
1503
|
-
# For positive rewards: interpolate from white to green
|
|
1504
|
-
# For negative rewards: interpolate from white to magenta
|
|
1505
|
-
# At reward = 0: white (255, 255, 255)
|
|
1506
|
-
# At reward = +1: green (0, 255, 0)
|
|
1507
|
-
# At reward = -1: magenta (255, 0, 255)
|
|
1508
|
-
|
|
1509
|
-
red_component = jnp.where(
|
|
1510
|
-
reward_clamped >= 0,
|
|
1511
|
-
(1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
|
|
1512
|
-
255, # Stay at 255 for all negative rewards
|
|
1513
|
-
)
|
|
1514
|
-
|
|
1515
|
-
green_component = jnp.where(
|
|
1516
|
-
reward_clamped >= 0,
|
|
1517
|
-
255, # Stay at 255 for all positive rewards
|
|
1518
|
-
(1 + reward_clamped) * 255, # Fade from white to magenta: 255 -> 0
|
|
1519
|
-
)
|
|
1520
|
-
|
|
1521
|
-
blue_component = jnp.where(
|
|
1522
|
-
reward_clamped >= 0,
|
|
1523
|
-
(1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
|
|
1524
|
-
255, # Stay at 255 for all negative rewards
|
|
1525
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
1526
|
+
object_id, state_params, respawn_timer
|
|
1525
1527
|
)
|
|
1526
|
-
|
|
1527
|
-
return jnp.stack(
|
|
1528
|
-
[red_component, green_component, blue_component], axis=-1
|
|
1529
|
-
).astype(jnp.uint8)
|
|
1528
|
+
return reward_grid
|
|
1530
1529
|
|
|
1531
1530
|
@partial(jax.jit, static_argnames=("self", "render_mode"))
|
|
1532
1531
|
def render(
|
|
@@ -1535,13 +1534,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1535
1534
|
params: EnvParams,
|
|
1536
1535
|
render_mode: str = "world",
|
|
1537
1536
|
):
|
|
1538
|
-
"""Render the environment state.
|
|
1539
|
-
|
|
1540
|
-
Args:
|
|
1541
|
-
state: Current environment state
|
|
1542
|
-
params: Environment parameters
|
|
1543
|
-
render_mode: One of "world", "world_true", "world_reward", "aperture", "aperture_true", "aperture_reward"
|
|
1544
|
-
"""
|
|
1537
|
+
"""Render the environment state."""
|
|
1545
1538
|
is_world_mode = render_mode in ("world", "world_true", "world_reward")
|
|
1546
1539
|
is_aperture_mode = render_mode in (
|
|
1547
1540
|
"aperture",
|
|
@@ -1552,27 +1545,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
1552
1545
|
is_reward_mode = render_mode in ("world_reward", "aperture_reward")
|
|
1553
1546
|
|
|
1554
1547
|
if is_world_mode:
|
|
1555
|
-
|
|
1556
|
-
|
|
1557
|
-
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
empty_mask = state.object_state.object_id == 0
|
|
1562
|
-
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1563
|
-
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1564
|
-
else:
|
|
1565
|
-
# Use default object colors
|
|
1566
|
-
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
1567
|
-
render_grid = state.object_state.object_id
|
|
1568
|
-
|
|
1569
|
-
def update_image(i, img):
|
|
1570
|
-
color = self.object_colors[i]
|
|
1571
|
-
mask = render_grid == i
|
|
1572
|
-
img = jnp.where(mask[..., None], color, img)
|
|
1573
|
-
return img
|
|
1574
|
-
|
|
1575
|
-
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
1548
|
+
img = get_base_image(
|
|
1549
|
+
state.object_state.object_id,
|
|
1550
|
+
state.object_state.color,
|
|
1551
|
+
self.object_colors,
|
|
1552
|
+
self.dynamic_biomes,
|
|
1553
|
+
)
|
|
1576
1554
|
|
|
1577
1555
|
# Define constants for all world modes
|
|
1578
1556
|
alpha = 0.2
|
|
@@ -1582,26 +1560,18 @@ class ForagaxEnv(environment.Environment):
|
|
|
1582
1560
|
|
|
1583
1561
|
if is_reward_mode:
|
|
1584
1562
|
# Construct 3x intermediate image
|
|
1585
|
-
# Each cell is 3x3, with reward color in center
|
|
1586
1563
|
reward_grid = self._compute_reward_grid(state)
|
|
1587
|
-
reward_colors = self._reward_to_color(reward_grid)
|
|
1588
|
-
|
|
1589
|
-
# Each cell has its base color in 8 pixels and reward color in 1 (center)
|
|
1590
|
-
# Create a 3x3 pattern mask for center pixels
|
|
1591
|
-
cell_mask = jnp.array(
|
|
1592
|
-
[[False, False, False], [False, True, False], [False, False, False]]
|
|
1593
|
-
)
|
|
1594
|
-
grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
|
|
1564
|
+
reward_colors = reward_to_color(reward_grid)
|
|
1595
1565
|
|
|
1596
|
-
# Repeat base colors
|
|
1566
|
+
# Repeat base colors to 3x scale
|
|
1597
1567
|
base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
|
|
1598
|
-
reward_colors_x3 = jnp.repeat(
|
|
1599
|
-
jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
|
|
1600
|
-
)
|
|
1601
1568
|
|
|
1602
|
-
# Composite base and reward colors
|
|
1603
|
-
img =
|
|
1604
|
-
|
|
1569
|
+
# Composite base and reward colors using helper
|
|
1570
|
+
img = apply_reward_overlay(
|
|
1571
|
+
base_img_x3,
|
|
1572
|
+
reward_colors,
|
|
1573
|
+
reward_grid,
|
|
1574
|
+
self.size,
|
|
1605
1575
|
)
|
|
1606
1576
|
|
|
1607
1577
|
# Tint the aperture region at 3x scale
|
|
@@ -1647,83 +1617,72 @@ class ForagaxEnv(environment.Environment):
|
|
|
1647
1617
|
img, state.object_state.object_id, self.size, len(self.object_ids)
|
|
1648
1618
|
)
|
|
1649
1619
|
|
|
1650
|
-
# Add grid lines
|
|
1651
|
-
|
|
1652
|
-
col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
|
|
1653
|
-
# skip first rows/cols as they are borders or managed by caller
|
|
1654
|
-
row_grid = row_grid.at[0].set(False)
|
|
1655
|
-
col_grid = col_grid.at[0].set(False)
|
|
1656
|
-
grid_mask = row_grid[:, None] | col_grid[None, :]
|
|
1657
|
-
img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
|
|
1620
|
+
# Add grid lines
|
|
1621
|
+
img = apply_grid_lines(img, self.size, self.grid_color_jax)
|
|
1658
1622
|
|
|
1659
1623
|
elif is_aperture_mode:
|
|
1660
1624
|
obs_grid = state.object_state.object_id
|
|
1661
1625
|
aperture = self._get_aperture(obs_grid, state.pos)
|
|
1662
1626
|
|
|
1663
|
-
|
|
1664
|
-
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
)
|
|
1668
|
-
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1627
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1628
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1629
|
+
)
|
|
1630
|
+
color_state = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1669
1631
|
|
|
1670
|
-
|
|
1671
|
-
|
|
1672
|
-
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
img = jnp.where(out_of_bounds[..., None], padding_color, img)
|
|
1685
|
-
else:
|
|
1686
|
-
# Use default object colors
|
|
1687
|
-
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
|
1688
|
-
img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
|
|
1632
|
+
img = get_base_image(
|
|
1633
|
+
aperture,
|
|
1634
|
+
color_state,
|
|
1635
|
+
self.object_colors,
|
|
1636
|
+
self.dynamic_biomes,
|
|
1637
|
+
)
|
|
1638
|
+
|
|
1639
|
+
if self.dynamic_biomes and self.nowrap:
|
|
1640
|
+
# For out-of-bounds, use padding object color
|
|
1641
|
+
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1642
|
+
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1643
|
+
out_of_bounds = y_out | x_out
|
|
1644
|
+
padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
|
|
1645
|
+
img = jnp.where(out_of_bounds[..., None], padding_color, img)
|
|
1689
1646
|
|
|
1690
1647
|
if is_reward_mode:
|
|
1691
1648
|
# Scale image by 3 to create space for reward visualization
|
|
1692
1649
|
img = img.astype(jnp.uint8)
|
|
1693
1650
|
img = jax.image.resize(
|
|
1694
1651
|
img,
|
|
1695
|
-
(
|
|
1652
|
+
(
|
|
1653
|
+
self.aperture_size[0] * 3,
|
|
1654
|
+
self.aperture_size[1] * 3,
|
|
1655
|
+
3,
|
|
1656
|
+
),
|
|
1696
1657
|
jax.image.ResizeMethod.NEAREST,
|
|
1697
1658
|
)
|
|
1698
1659
|
|
|
1699
1660
|
# Compute rewards for aperture region
|
|
1700
|
-
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1701
|
-
self._compute_aperture_coordinates(state.pos)
|
|
1702
|
-
)
|
|
1703
|
-
|
|
1704
|
-
# Get reward grid only for aperture region
|
|
1705
|
-
aperture_object_ids = state.object_state.object_id[
|
|
1706
|
-
y_coords_adj, x_coords_adj
|
|
1707
|
-
]
|
|
1708
1661
|
aperture_params = state.object_state.state_params[
|
|
1709
1662
|
y_coords_adj, x_coords_adj
|
|
1710
1663
|
]
|
|
1664
|
+
aperture_timer = self._get_aperture(
|
|
1665
|
+
state.object_state.respawn_timer, state.pos
|
|
1666
|
+
)
|
|
1711
1667
|
aperture_rewards = self._compute_reward_grid(
|
|
1712
|
-
state, aperture_object_ids, aperture_params
|
|
1668
|
+
state, aperture, aperture_params, aperture_timer
|
|
1713
1669
|
)
|
|
1714
1670
|
|
|
1715
1671
|
# Convert rewards to colors
|
|
1716
|
-
reward_colors = self._reward_to_color(aperture_rewards)
|
|
1672
|
+
reward_colors = reward_to_color(aperture_rewards)
|
|
1717
1673
|
|
|
1718
|
-
#
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
1674
|
+
# Apply reward overlay using helper
|
|
1675
|
+
img = apply_reward_overlay(
|
|
1676
|
+
img,
|
|
1677
|
+
reward_colors,
|
|
1678
|
+
aperture_rewards,
|
|
1679
|
+
self.aperture_size,
|
|
1680
|
+
)
|
|
1722
1681
|
|
|
1723
1682
|
# Draw agent in the center (all 9 cells of the 3x3 block)
|
|
1724
1683
|
center_y, center_x = (
|
|
1725
|
-
self.aperture_size[1] // 2,
|
|
1726
1684
|
self.aperture_size[0] // 2,
|
|
1685
|
+
self.aperture_size[1] // 2,
|
|
1727
1686
|
)
|
|
1728
1687
|
agent_offsets = jnp.array(
|
|
1729
1688
|
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
@@ -1757,17 +1716,13 @@ class ForagaxEnv(environment.Environment):
|
|
|
1757
1716
|
)
|
|
1758
1717
|
|
|
1759
1718
|
if is_true_mode:
|
|
1760
|
-
# Apply true object borders
|
|
1719
|
+
# Apply true object borders
|
|
1761
1720
|
img = apply_true_borders(
|
|
1762
1721
|
img, aperture, self.aperture_size, len(self.object_ids)
|
|
1763
1722
|
)
|
|
1764
1723
|
|
|
1765
|
-
# Add grid lines
|
|
1766
|
-
|
|
1767
|
-
row_indices = jnp.arange(1, self.aperture_size[0]) * 24
|
|
1768
|
-
col_indices = jnp.arange(1, self.aperture_size[1]) * 24
|
|
1769
|
-
img = img.at[row_indices, :].set(grid_color)
|
|
1770
|
-
img = img.at[:, col_indices].set(grid_color)
|
|
1724
|
+
# Add grid lines
|
|
1725
|
+
img = apply_grid_lines(img, self.aperture_size, self.grid_color_jax)
|
|
1771
1726
|
|
|
1772
1727
|
else:
|
|
1773
1728
|
raise ValueError(f"Unknown render_mode: {render_mode}")
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Rendering utilities for Foragax environments."""
|
|
2
|
+
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
from foragax.colors import hsv_to_rgb_255
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def apply_true_borders(
    base_img: jax.Array,
    true_grid: jax.Array,
    grid_size: Tuple[int, int],
    num_objects: int,
    cell_size: int = 24,
    border_width: int = 2,
) -> jax.Array:
    """Overlay a per-object-type colored frame on every rendered grid cell.

    Each grid cell occupies a ``cell_size`` x ``cell_size`` pixel square in
    the rendered image. The outermost ``border_width`` pixels on every side
    of that square are replaced with a color derived from the object ID
    stored in ``true_grid``; all other pixels are taken from ``base_img``.

    Args:
        base_img: Rendered image of shape
            ``(grid_size[0] * cell_size, grid_size[1] * cell_size, 3)``.
        true_grid: Grid of integer object IDs, used to index the hue table.
        grid_size: ``(height, width)`` of the grid, in cells.
        num_objects: Number of object types; one hue is generated per type.
        cell_size: Side length in pixels of one rendered cell. Defaults to
            24, matching the default of ``apply_grid_lines``.
        border_width: Thickness in pixels of the colored frame. Defaults to 2.

    Returns:
        Image in which border pixels carry the per-object border color and
        interior pixels are unchanged.
    """
    # One evenly spaced hue per object type, covering [0, 1) without
    # repeating the endpoint.
    hues = jnp.linspace(0, 1, num_objects, endpoint=False)

    # Look up the hue of every cell and convert it to RGB.
    # NOTE(review): hsv_to_rgb_255 comes from foragax.colors; presumably it
    # returns 0-255 RGB triples with a trailing channel axis -- confirm there.
    border_colors = hsv_to_rgb_255(hues[true_grid])

    img_height = grid_size[0] * cell_size
    img_width = grid_size[1] * cell_size

    # Blow the per-cell colors up to pixel resolution; NEAREST keeps each
    # cell a flat block of a single color.
    border_img = jax.image.resize(
        border_colors,
        (img_height, img_width, 3),
        jax.image.ResizeMethod.NEAREST,
    )

    # Pixel offsets inside their own cell, along each axis.
    y_idx = jnp.arange(img_height) % cell_size
    x_idx = jnp.arange(img_width) % cell_size

    # A pixel belongs to the frame when its in-cell offset lies within
    # border_width of either edge (offsets 0, 1, 22, 23 with the defaults).
    is_border_row = (y_idx < border_width) | (y_idx >= cell_size - border_width)
    is_border_col = (x_idx < border_width) | (x_idx >= cell_size - border_width)
    border_mask = is_border_row[:, None] | is_border_col[None, :]

    # Frame pixels take the border color, everything else keeps base_img.
    return jnp.where(border_mask[..., None], border_img, base_img)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def reward_to_color(reward: jax.Array) -> jax.Array:
    """Map reward values onto a magenta-white-green diverging color scale.

    The reward is first clipped to ``[-1, 1]``. Zero maps to white
    ``(255, 255, 255)``, ``+1`` to pure green ``(0, 255, 0)`` and ``-1`` to
    pure magenta ``(255, 0, 255)``, with linear interpolation in between.

    Args:
        reward: Reward value(s); anything outside ``[-1, 1]`` saturates.

    Returns:
        RGB array of shape ``reward.shape + (3,)`` with dtype ``uint8``.
    """
    clipped = jnp.clip(reward, -1.0, 1.0)
    non_negative = clipped >= 0

    # Red and blue behave identically: they fade out as the reward climbs
    # towards +1 and stay saturated for every negative reward.
    red_blue = jnp.where(non_negative, (1 - clipped) * 255, 255)

    # Green is the mirror image: saturated for non-negative rewards and
    # fading out as the reward drops towards -1.
    green = jnp.where(non_negative, 255, (1 + clipped) * 255)

    rgb = jnp.stack([red_blue, green, red_blue], axis=-1)
    return rgb.astype(jnp.uint8)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_base_image(
    object_id: jax.Array,
    color_state: jax.Array,
    object_colors: jax.Array,
    dynamic_biomes: bool,
) -> jax.Array:
    """Build the un-scaled RGB image for a grid of objects.

    Args:
        object_id: Grid of integer object IDs; ID 0 marks an empty cell.
        color_state: Per-cell RGB colors, used only when ``dynamic_biomes``
            is true.
        object_colors: Color table indexed by object ID, used only when
            ``dynamic_biomes`` is false.
        dynamic_biomes: Selects per-cell colors (true) or the per-ID color
            table (false).

    Returns:
        ``uint8`` RGB image with one pixel per grid cell.
    """
    if not dynamic_biomes:
        # Static palette: a plain lookup of each cell's ID in the table.
        return object_colors[object_id].astype(jnp.uint8)

    # Per-instance colors come from the state, but empty cells (ID 0) are
    # always painted white regardless of what color_state holds there.
    white = jnp.full((3,), 255, dtype=jnp.uint8)
    is_empty = (object_id == 0)[..., None]
    return jnp.where(is_empty, white, color_state).astype(jnp.uint8)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def apply_grid_lines(
    img: jax.Array,
    grid_size: Tuple[int, int],
    grid_color: jax.Array,
    cell_size: int = 24,
) -> jax.Array:
    """Draw one-pixel grid lines between the cells of a rendered image.

    A line is drawn on the first pixel row/column of every cell except the
    very first one, so the top and left image edges are left untouched.

    Args:
        img: Rendered image of shape
            ``(grid_size[0] * cell_size, grid_size[1] * cell_size, 3)``.
        grid_size: ``(height, width)`` of the grid, in cells.
        grid_color: RGB color of the grid lines.
        cell_size: Side length in pixels of one rendered cell.

    Returns:
        Image with grid lines drawn, in the same dtype as ``img``.
    """
    img = jnp.asarray(img)

    row_px = jnp.arange(grid_size[0] * cell_size)
    col_px = jnp.arange(grid_size[1] * cell_size)

    # Cell boundaries sit at multiples of cell_size; index 0 is excluded so
    # the outer top/left edge is not overdrawn.
    row_grid = (row_px % cell_size == 0) & (row_px > 0)
    col_grid = (col_px % cell_size == 0) & (col_px > 0)
    grid_mask = row_grid[:, None] | col_grid[None, :]

    # jnp.where promotes its two value operands to a common dtype, so a
    # grid_color wider than the image (e.g. int32 vs uint8) would silently
    # change the dtype of the returned image. Casting the color first keeps
    # the image dtype stable, as an in-place `.at[...].set()` would.
    line_color = jnp.asarray(grid_color, dtype=img.dtype)
    return jnp.where(grid_mask[..., None], line_color, img)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def apply_reward_overlay(
    base_img: jax.Array,
    reward_colors: jax.Array,
    reward_grid: jax.Array,
    grid_size: Tuple[int, int],
) -> jax.Array:
    """Paint a reward-colored dot in the center of each rewarding cell.

    The image is expected at 3x scale, i.e. every grid cell is a 3x3 pixel
    block. Only the middle pixel of a block is ever recolored, and only when
    the cell's reward is meaningfully non-zero (``abs(reward) > 1e-5``).

    Args:
        base_img: Image at 3x scale, shape
            ``(grid_size[0] * 3, grid_size[1] * 3, 3)``.
        reward_colors: One RGB color per grid cell.
        reward_grid: One reward value per grid cell.
        grid_size: ``(height, width)`` of the grid, in cells.

    Returns:
        Image with the reward dots drawn over ``base_img``.
    """

    def _triple(cells: jax.Array) -> jax.Array:
        # Expand a per-cell array to pixel resolution (3 pixels per cell
        # along both leading axes).
        return jnp.repeat(jnp.repeat(cells, 3, axis=0), 3, axis=1)

    n_rows, n_cols = grid_size

    # The middle pixel of a 3x3 block has in-block offset 1 on both axes.
    mid_row = (jnp.arange(n_rows * 3) % 3) == 1
    mid_col = (jnp.arange(n_cols * 3) % 3) == 1
    is_center_pixel = mid_row[:, None] & mid_col[None, :]

    # Cells whose reward is effectively zero get no dot at all.
    has_reward = _triple(jnp.abs(reward_grid) > 1e-5)

    dot_mask = (is_center_pixel & has_reward)[..., None]
    return jnp.where(dot_mask, _triple(reward_colors), base_img)
|