continual-foragax 0.20.1__tar.gz → 0.22.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/PKG-INFO +1 -1
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/pyproject.toml +2 -2
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/PKG-INFO +1 -1
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/SOURCES.txt +1 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/env.py +63 -51
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/objects.py +28 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/registry.py +31 -1
- continual_foragax-0.22.0/tests/test_benchmark.py +239 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/tests/test_foragax.py +120 -233
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/README.md +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/setup.cfg +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/dependency_links.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/entry_points.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/requires.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/top_level.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/__init__.py +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/colors.py +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100928.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100929.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100930.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID100931.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106714.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106715.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106716.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106717.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106718.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106930.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106931.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106932.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106933.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106934.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106935.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106936.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106937.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106938.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106939.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106940.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106941.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106942.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106943.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106994.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106995.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106996.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106997.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106998.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID106999.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107000.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107001.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107002.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107003.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107004.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107005.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107006.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107007.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107008.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107009.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107010.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107011.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107012.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107013.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107014.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107015.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107016.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107017.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107018.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107019.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107020.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107021.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107022.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107023.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107024.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107025.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107026.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107027.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107028.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107029.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107030.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107031.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107032.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107033.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107034.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107035.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107036.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107037.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107038.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107039.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107040.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107041.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107042.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107043.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107044.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107045.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107046.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107047.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107048.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107049.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107050.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107051.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107052.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107053.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107054.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107055.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107056.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107057.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107058.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107059.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107060.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107061.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107062.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107063.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107064.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107065.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107066.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107067.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107068.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107069.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107070.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID107071.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115808.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID115812.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID146811.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156831.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156835.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156839.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156843.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156847.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156851.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156855.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156859.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156863.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156867.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156871.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156875.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156879.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156883.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/elements.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/metadata.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/data/ECA_non-blended_custom/sources.txt +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/rendering.py +0 -0
- {continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/foragax/weather.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "continual-foragax"
|
3
|
-
version = "0.
|
3
|
+
version = "0.22.0"
|
4
4
|
description = "A continual reinforcement learning benchmark"
|
5
5
|
readme = "README.md"
|
6
6
|
authors = [
|
@@ -30,7 +30,7 @@ build-backend = "setuptools.build_meta"
|
|
30
30
|
[tool]
|
31
31
|
[tool.commitizen]
|
32
32
|
name = "cz_conventional_commits"
|
33
|
-
version = "0.
|
33
|
+
version = "0.22.0"
|
34
34
|
tag_format = "$version"
|
35
35
|
version_files = ["pyproject.toml"]
|
36
36
|
|
{continual_foragax-0.20.1 → continual_foragax-0.22.0}/src/continual_foragax.egg-info/SOURCES.txt
RENAMED
@@ -136,4 +136,5 @@ src/foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt
|
|
136
136
|
src/foragax/data/ECA_non-blended_custom/elements.txt
|
137
137
|
src/foragax/data/ECA_non-blended_custom/metadata.txt
|
138
138
|
src/foragax/data/ECA_non-blended_custom/sources.txt
|
139
|
+
tests/test_benchmark.py
|
139
140
|
tests/test_foragax.py
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, Tuple, Union
|
|
10
10
|
|
11
11
|
import jax
|
12
12
|
import jax.numpy as jnp
|
13
|
+
import numpy as np
|
13
14
|
from flax import struct
|
14
15
|
from gymnax.environments import environment, spaces
|
15
16
|
|
@@ -66,13 +67,16 @@ class ForagaxEnv(environment.Environment):
|
|
66
67
|
|
67
68
|
def __init__(
|
68
69
|
self,
|
70
|
+
name: str = "Foragax-v0",
|
69
71
|
size: Union[Tuple[int, int], int] = (10, 10),
|
70
72
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
71
73
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
72
74
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
73
75
|
nowrap: bool = False,
|
76
|
+
deterministic_spawn: bool = False,
|
74
77
|
):
|
75
78
|
super().__init__()
|
79
|
+
self._name = name
|
76
80
|
if isinstance(size, int):
|
77
81
|
size = (size, size)
|
78
82
|
self.size = size
|
@@ -81,6 +85,7 @@ class ForagaxEnv(environment.Environment):
|
|
81
85
|
aperture_size = (aperture_size, aperture_size)
|
82
86
|
self.aperture_size = aperture_size
|
83
87
|
self.nowrap = nowrap
|
88
|
+
self.deterministic_spawn = deterministic_spawn
|
84
89
|
objects = (EMPTY,) + objects
|
85
90
|
if self.nowrap:
|
86
91
|
objects = objects + (PADDING,)
|
@@ -103,12 +108,35 @@ class ForagaxEnv(environment.Environment):
|
|
103
108
|
self.biome_object_frequencies = jnp.array(
|
104
109
|
[b.object_frequencies for b in biomes]
|
105
110
|
)
|
106
|
-
self.biome_starts =
|
111
|
+
self.biome_starts = np.array(
|
107
112
|
[b.start if b.start is not None else (-1, -1) for b in biomes]
|
108
113
|
)
|
109
|
-
self.biome_stops =
|
114
|
+
self.biome_stops = np.array(
|
110
115
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
111
116
|
)
|
117
|
+
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
118
|
+
self.biome_masks = []
|
119
|
+
for i in range(self.biome_object_frequencies.shape[0]):
|
120
|
+
# Create mask for the biome
|
121
|
+
start = jax.lax.select(
|
122
|
+
self.biome_starts[i, 0] == -1,
|
123
|
+
jnp.array([0, 0]),
|
124
|
+
self.biome_starts[i],
|
125
|
+
)
|
126
|
+
stop = jax.lax.select(
|
127
|
+
self.biome_stops[i, 0] == -1,
|
128
|
+
jnp.array(self.size),
|
129
|
+
self.biome_stops[i],
|
130
|
+
)
|
131
|
+
rows = jnp.arange(self.size[1])[:, None]
|
132
|
+
cols = jnp.arange(self.size[0])
|
133
|
+
mask = (
|
134
|
+
(rows >= start[1])
|
135
|
+
& (rows < stop[1])
|
136
|
+
& (cols >= start[0])
|
137
|
+
& (cols < stop[0])
|
138
|
+
)
|
139
|
+
self.biome_masks.append(mask)
|
112
140
|
|
113
141
|
@property
|
114
142
|
def default_params(self) -> EnvParams:
|
@@ -196,57 +224,18 @@ class ForagaxEnv(environment.Environment):
|
|
196
224
|
self, key: jax.Array, params: EnvParams
|
197
225
|
) -> Tuple[jax.Array, EnvState]:
|
198
226
|
"""Reset environment state."""
|
199
|
-
key, subkey = jax.random.split(key)
|
200
|
-
|
201
227
|
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
202
|
-
|
203
|
-
iter_key = subkey
|
228
|
+
key, iter_key = jax.random.split(key)
|
204
229
|
for i in range(self.biome_object_frequencies.shape[0]):
|
205
230
|
iter_key, biome_key = jax.random.split(iter_key)
|
206
|
-
|
207
|
-
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
208
|
-
|
209
|
-
# Create mask for the biome
|
210
|
-
start = jax.lax.select(
|
211
|
-
self.biome_starts[i, 0] == -1,
|
212
|
-
jnp.array([0, 0]),
|
213
|
-
self.biome_starts[i],
|
214
|
-
)
|
215
|
-
stop = jax.lax.select(
|
216
|
-
self.biome_stops[i, 0] == -1,
|
217
|
-
jnp.array(self.size),
|
218
|
-
self.biome_stops[i],
|
219
|
-
)
|
220
|
-
|
221
|
-
rows = jnp.arange(self.size[1])[:, None]
|
222
|
-
cols = jnp.arange(self.size[0])
|
231
|
+
mask = self.biome_masks[i]
|
223
232
|
|
224
|
-
|
225
|
-
(
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
# Generate objects for this biome and update the main grid
|
232
|
-
biome_freqs = self.biome_object_frequencies[i]
|
233
|
-
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
234
|
-
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
235
|
-
|
236
|
-
cumulative_freqs = jnp.cumsum(
|
237
|
-
jnp.concatenate([jnp.array([0.0]), all_freqs])
|
238
|
-
)
|
239
|
-
|
240
|
-
# Determine which object to place in each cell
|
241
|
-
# The last object ID will be used for any value of grid_rand >= cumulative_freqs[-1]
|
242
|
-
# so we don't need to cap grid_rand
|
243
|
-
obj_ids_for_biome = jnp.arange(len(all_freqs))
|
244
|
-
cell_obj_ids = (
|
245
|
-
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
246
|
-
)
|
247
|
-
biome_objects = obj_ids_for_biome[cell_obj_ids]
|
248
|
-
|
249
|
-
object_grid = jnp.where(mask, biome_objects, object_grid)
|
233
|
+
if self.deterministic_spawn:
|
234
|
+
biome_objects = self.generate_biome_new(i, biome_key)
|
235
|
+
object_grid = object_grid.at[mask].set(biome_objects)
|
236
|
+
else:
|
237
|
+
biome_objects = self.generate_biome_old(i, biome_key)
|
238
|
+
object_grid = jnp.where(mask, biome_objects, object_grid)
|
250
239
|
|
251
240
|
# Place agent in the center of the world and ensure the cell is empty.
|
252
241
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
@@ -260,6 +249,25 @@ class ForagaxEnv(environment.Environment):
|
|
260
249
|
|
261
250
|
return self.get_obs(state, params), state
|
262
251
|
|
252
|
+
def generate_biome_old(self, i: int, biome_key: jax.Array):
|
253
|
+
biome_freqs = self.biome_object_frequencies[i]
|
254
|
+
grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
|
255
|
+
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
256
|
+
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
257
|
+
cumulative_freqs = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), all_freqs]))
|
258
|
+
biome_objects = jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
259
|
+
return biome_objects
|
260
|
+
|
261
|
+
def generate_biome_new(self, i: int, biome_key: jax.Array):
|
262
|
+
biome_freqs = self.biome_object_frequencies[i]
|
263
|
+
grid = jnp.linspace(0, 1, self.biome_sizes[i], endpoint=False)
|
264
|
+
biome_objects = len(biome_freqs) - jnp.searchsorted(
|
265
|
+
jnp.cumsum(biome_freqs[::-1]), grid, side="right"
|
266
|
+
)
|
267
|
+
flat_biome_objects = biome_objects.flatten()
|
268
|
+
shuffled_objects = jax.random.permutation(biome_key, flat_biome_objects)
|
269
|
+
return shuffled_objects
|
270
|
+
|
263
271
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
264
272
|
"""Foragax is a continuing environment."""
|
265
273
|
return False
|
@@ -267,7 +275,7 @@ class ForagaxEnv(environment.Environment):
|
|
267
275
|
@property
|
268
276
|
def name(self) -> str:
|
269
277
|
"""Environment name."""
|
270
|
-
return
|
278
|
+
return self._name
|
271
279
|
|
272
280
|
@property
|
273
281
|
def num_actions(self) -> int:
|
@@ -438,13 +446,17 @@ class ForagaxObjectEnv(ForagaxEnv):
|
|
438
446
|
|
439
447
|
def __init__(
|
440
448
|
self,
|
449
|
+
name: str = "Foragax-v0",
|
441
450
|
size: Union[Tuple[int, int], int] = (10, 10),
|
442
451
|
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
443
452
|
objects: Tuple[BaseForagaxObject, ...] = (),
|
444
453
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
445
454
|
nowrap: bool = False,
|
455
|
+
deterministic_spawn: bool = False,
|
446
456
|
):
|
447
|
-
super().__init__(
|
457
|
+
super().__init__(
|
458
|
+
name, size, aperture_size, objects, biomes, nowrap, deterministic_spawn
|
459
|
+
)
|
448
460
|
|
449
461
|
# Compute unique colors and mapping for partial observability
|
450
462
|
# Exclude EMPTY (index 0) from color channels
|
@@ -242,6 +242,34 @@ GREEN_FAKE_2 = NormalRegenForagaxObject(
|
|
242
242
|
mean_regen_delay=10,
|
243
243
|
std_regen_delay=1,
|
244
244
|
)
|
245
|
+
BROWN_MOREL_UNIFORM = DefaultForagaxObject(
|
246
|
+
name="brown_morel",
|
247
|
+
reward=10.0,
|
248
|
+
collectable=True,
|
249
|
+
color=(63, 30, 25),
|
250
|
+
regen_delay=(90, 110),
|
251
|
+
)
|
252
|
+
BROWN_OYSTER_UNIFORM = DefaultForagaxObject(
|
253
|
+
name="brown_oyster",
|
254
|
+
reward=1.0,
|
255
|
+
collectable=True,
|
256
|
+
color=(63, 30, 25),
|
257
|
+
regen_delay=(9, 11),
|
258
|
+
)
|
259
|
+
GREEN_DEATHCAP_UNIFORM = DefaultForagaxObject(
|
260
|
+
name="green_deathcap",
|
261
|
+
reward=-5.0,
|
262
|
+
collectable=True,
|
263
|
+
color=(0, 255, 0),
|
264
|
+
regen_delay=(9, 11),
|
265
|
+
)
|
266
|
+
GREEN_FAKE_UNIFORM = DefaultForagaxObject(
|
267
|
+
name="green_fake",
|
268
|
+
reward=0.0,
|
269
|
+
collectable=True,
|
270
|
+
color=(0, 255, 0),
|
271
|
+
regen_delay=(9, 11),
|
272
|
+
)
|
245
273
|
|
246
274
|
|
247
275
|
def create_weather_objects(
|
@@ -12,12 +12,16 @@ from foragax.env import (
|
|
12
12
|
from foragax.objects import (
|
13
13
|
BROWN_MOREL,
|
14
14
|
BROWN_MOREL_2,
|
15
|
+
BROWN_MOREL_UNIFORM,
|
15
16
|
BROWN_OYSTER,
|
17
|
+
BROWN_OYSTER_UNIFORM,
|
16
18
|
GREEN_DEATHCAP,
|
17
19
|
GREEN_DEATHCAP_2,
|
18
20
|
GREEN_DEATHCAP_3,
|
21
|
+
GREEN_DEATHCAP_UNIFORM,
|
19
22
|
GREEN_FAKE,
|
20
23
|
GREEN_FAKE_2,
|
24
|
+
GREEN_FAKE_UNIFORM,
|
21
25
|
LARGE_MOREL,
|
22
26
|
LARGE_OYSTER,
|
23
27
|
MEDIUM_MOREL,
|
@@ -147,6 +151,31 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
147
151
|
"biomes": None,
|
148
152
|
"nowrap": True,
|
149
153
|
},
|
154
|
+
"ForagaxTwoBiome-v8": {
|
155
|
+
"size": None,
|
156
|
+
"aperture_size": None,
|
157
|
+
"objects": (
|
158
|
+
BROWN_MOREL_UNIFORM,
|
159
|
+
BROWN_OYSTER_UNIFORM,
|
160
|
+
GREEN_DEATHCAP_UNIFORM,
|
161
|
+
GREEN_FAKE_UNIFORM,
|
162
|
+
),
|
163
|
+
"biomes": None,
|
164
|
+
"nowrap": True,
|
165
|
+
},
|
166
|
+
"ForagaxTwoBiome-v9": {
|
167
|
+
"size": None,
|
168
|
+
"aperture_size": None,
|
169
|
+
"objects": (
|
170
|
+
BROWN_MOREL_UNIFORM,
|
171
|
+
BROWN_OYSTER_UNIFORM,
|
172
|
+
GREEN_DEATHCAP_UNIFORM,
|
173
|
+
GREEN_FAKE_UNIFORM,
|
174
|
+
),
|
175
|
+
"biomes": None,
|
176
|
+
"nowrap": True,
|
177
|
+
"deterministic_spawn": True,
|
178
|
+
},
|
150
179
|
"ForagaxTwoBiomeSmall-v1": {
|
151
180
|
"size": (16, 8),
|
152
181
|
"aperture_size": None,
|
@@ -217,7 +246,7 @@ def make(
|
|
217
246
|
if nowrap is not None:
|
218
247
|
config["nowrap"] = nowrap
|
219
248
|
|
220
|
-
if env_id
|
249
|
+
if env_id in ("ForagaxTwoBiome-v7", "ForagaxTwoBiome-v8", "ForagaxTwoBiome-v9"):
|
221
250
|
margin = aperture_size[1] // 2 + 1
|
222
251
|
width = 2 * margin + 9
|
223
252
|
config["size"] = (width, 15)
|
@@ -270,5 +299,6 @@ def make(
|
|
270
299
|
raise ValueError(f"Unknown observation type: {observation_type}")
|
271
300
|
|
272
301
|
env_class = env_class_map[observation_type]
|
302
|
+
config["name"] = env_id
|
273
303
|
|
274
304
|
return env_class(**config)
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import chex
|
2
|
+
import jax
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
from foragax.env import Actions, Biome, ForagaxObjectEnv, ForagaxRGBEnv, ForagaxWorldEnv
|
6
|
+
from foragax.objects import FLOWER, WALL
|
7
|
+
|
8
|
+
|
9
|
+
def test_benchmark_vision(benchmark):
|
10
|
+
env = ForagaxObjectEnv(size=7, aperture_size=3, objects=(WALL,))
|
11
|
+
params = env.default_params
|
12
|
+
key = jax.random.key(0)
|
13
|
+
_, state = env.reset(key, params)
|
14
|
+
|
15
|
+
grid = jnp.zeros((7, 7), dtype=int)
|
16
|
+
grid = grid.at[4, 3].set(1)
|
17
|
+
grid = grid.at[5, 3].set(1)
|
18
|
+
grid = grid.at[2, 0].set(1)
|
19
|
+
state = state.replace(object_grid=grid)
|
20
|
+
|
21
|
+
@jax.jit
|
22
|
+
def _run(state, key):
|
23
|
+
key, step_key = jax.random.split(key)
|
24
|
+
obs, new_state, _, _, _ = env.step(step_key, state, Actions.DOWN, params)
|
25
|
+
return obs, new_state
|
26
|
+
|
27
|
+
# warm-up
|
28
|
+
obs, new_state = _run(state, key)
|
29
|
+
|
30
|
+
expected = jnp.zeros((3, 3, 1), dtype=int)
|
31
|
+
expected = expected.at[2, 1, 0].set(1)
|
32
|
+
|
33
|
+
chex.assert_trees_all_equal(new_state.pos, jnp.array([3, 3]))
|
34
|
+
chex.assert_trees_all_equal(obs, expected)
|
35
|
+
|
36
|
+
def benchmark_fn():
|
37
|
+
# use a fixed key for benchmark consistency
|
38
|
+
_run(state, jax.random.key(1))[0].block_until_ready()
|
39
|
+
|
40
|
+
benchmark(benchmark_fn)
|
41
|
+
|
42
|
+
|
43
|
+
def test_benchmark_creation(benchmark):
|
44
|
+
env = ForagaxObjectEnv(
|
45
|
+
size=1_000,
|
46
|
+
aperture_size=31,
|
47
|
+
objects=(WALL, FLOWER),
|
48
|
+
biomes=(Biome(object_frequencies=(0.05, 0.05)),),
|
49
|
+
)
|
50
|
+
params = env.default_params
|
51
|
+
|
52
|
+
@jax.jit
|
53
|
+
def _build(key):
|
54
|
+
_, state = env.reset(key, params)
|
55
|
+
return state
|
56
|
+
|
57
|
+
# no warm-up
|
58
|
+
|
59
|
+
def benchmark_fn():
|
60
|
+
_build(jax.random.key(1)).pos.block_until_ready()
|
61
|
+
|
62
|
+
benchmark(benchmark_fn)
|
63
|
+
|
64
|
+
|
65
|
+
def test_benchmark_small_env(benchmark):
|
66
|
+
env = ForagaxObjectEnv(
|
67
|
+
size=1_000,
|
68
|
+
aperture_size=11,
|
69
|
+
objects=(WALL, FLOWER),
|
70
|
+
biomes=(Biome(object_frequencies=(0.1, 0.1)),),
|
71
|
+
)
|
72
|
+
params = env.default_params
|
73
|
+
key = jax.random.key(0)
|
74
|
+
key, reset_key = jax.random.split(key)
|
75
|
+
_, state = env.reset(reset_key, params)
|
76
|
+
|
77
|
+
@jax.jit
|
78
|
+
def _run(state, key):
|
79
|
+
def f(carry, _):
|
80
|
+
state, key = carry
|
81
|
+
key, step_key = jax.random.split(key, 2)
|
82
|
+
_, new_state, _, _, _ = env.step(step_key, state, Actions.DOWN, params)
|
83
|
+
return (new_state, key), None
|
84
|
+
|
85
|
+
(final_state, _), _ = jax.lax.scan(f, (state, key), None, length=1000)
|
86
|
+
return final_state
|
87
|
+
|
88
|
+
key, run_key = jax.random.split(key)
|
89
|
+
_run(state, run_key).pos.block_until_ready()
|
90
|
+
|
91
|
+
def benchmark_fn():
|
92
|
+
key, run_key = jax.random.split(jax.random.key(1))
|
93
|
+
_run(state, run_key).pos.block_until_ready()
|
94
|
+
|
95
|
+
benchmark(benchmark_fn)
|
96
|
+
|
97
|
+
|
98
|
+
def test_benchmark_big_env(benchmark):
    """Benchmark a 100-step jitted rollout in a large object env.

    The env uses ``size=10_000``, ``aperture_size=61`` and a single biome
    with object frequencies ``(0.05, 0.05)`` for ``(WALL, FLOWER)``. The
    agent always takes ``Actions.DOWN``. Reset and JIT compilation happen
    during setup, so only the scanned rollout is timed.
    """
    env = ForagaxObjectEnv(
        size=10_000,
        aperture_size=61,
        objects=(WALL, FLOWER),
        biomes=(Biome(object_frequencies=(0.05, 0.05)),),
    )
    params = env.default_params
    key = jax.random.key(0)

    # Reset is part of the setup, not benchmarked
    key, reset_key = jax.random.split(key)
    _, state = env.reset(reset_key, params)

    @jax.jit
    def _run(state, key):
        def f(carry, _):
            state, key = carry
            key, step_key = jax.random.split(key, 2)
            _, new_state, _, _, _ = env.step(step_key, state, Actions.DOWN, params)
            return (new_state, key), None

        (final_state, _), _ = jax.lax.scan(f, (state, key), None, length=100)
        return final_state

    # warm-up compilation
    _, warmup_key = jax.random.split(key)
    _run(state, warmup_key).pos.block_until_ready()

    # Use a fixed key for benchmark consistency. Derive it once, outside the
    # timed function, so host-side key splitting is not measured.
    _, run_key = jax.random.split(jax.random.key(1))

    def benchmark_fn():
        _run(state, run_key).pos.block_until_ready()

    benchmark(benchmark_fn)
|
133
|
+
|
134
|
+
|
135
|
+
def test_benchmark_vmap_env(benchmark):
    """Benchmark a 1,000-step jitted rollout over 100 vmapped object envs.

    Each env uses ``size=1_000``, ``aperture_size=11`` and a single biome
    with object frequencies ``(0.1, 0.1)`` for ``(WALL, FLOWER)``. Every env
    always takes ``Actions.DOWN`` with its own per-step key. Batched reset
    and JIT compilation happen during setup, so only the rollout is timed.
    """
    num_envs = 100
    env = ForagaxObjectEnv(
        size=1_000,
        aperture_size=11,
        objects=(WALL, FLOWER),
        biomes=(Biome(object_frequencies=(0.1, 0.1)),),
    )
    params = env.default_params
    key = jax.random.key(0)

    # Reset is part of the setup, not benchmarked
    key, reset_key = jax.random.split(key)
    reset_keys = jax.random.split(reset_key, num_envs)
    # vmap over keys only; params are shared. Index 1 selects the states.
    states = jax.vmap(env.reset, in_axes=(0, None))(reset_keys, params)[1]

    @jax.jit
    def _run(states, key):
        def f(carry, _):
            states, key = carry
            key, step_key = jax.random.split(key, 2)
            step_keys = jax.random.split(step_key, num_envs)
            # Batch over keys and states; action and params are broadcast.
            _, new_states, _, _, _ = jax.vmap(env.step, in_axes=(0, 0, None, None))(
                step_keys, states, Actions.DOWN, params
            )
            return (new_states, key), None

        (final_states, _), _ = jax.lax.scan(f, (states, key), None, length=1000)
        return final_states

    # warm-up compilation
    _, warmup_key = jax.random.split(key)
    _run(states, warmup_key).pos.block_until_ready()

    # Use a fixed key for benchmark consistency. Derive it once, outside the
    # timed function, so host-side key splitting is not measured.
    _, run_key = jax.random.split(jax.random.key(1))

    def benchmark_fn():
        _run(states, run_key).pos.block_until_ready()

    benchmark(benchmark_fn)
|
175
|
+
|
176
|
+
|
177
|
+
def test_benchmark_small_env_color(benchmark):
    """Benchmark a 100-step jitted rollout in a small RGB-observation env.

    The ``ForagaxRGBEnv`` uses ``size=1_000``, ``aperture_size=15`` and a
    single biome with object frequencies ``(0.05, 0.05)`` for
    ``(WALL, FLOWER)``. The agent always takes ``Actions.DOWN``. Reset and
    JIT compilation happen during setup, so only the rollout is timed.
    """
    env = ForagaxRGBEnv(
        size=1_000,
        aperture_size=15,
        objects=(WALL, FLOWER),
        biomes=(Biome(object_frequencies=(0.05, 0.05)),),
    )
    params = env.default_params
    key = jax.random.key(0)
    key, reset_key = jax.random.split(key)
    _, state = env.reset(reset_key, params)

    @jax.jit
    def _run(state, key):
        def f(carry, _):
            state, key = carry
            key, step_key = jax.random.split(key, 2)
            _, new_state, _, _, _ = env.step(step_key, state, Actions.DOWN, params)
            return (new_state, key), None

        (final_state, _), _ = jax.lax.scan(f, (state, key), None, length=100)
        return final_state

    # warm-up compilation
    _, warmup_key = jax.random.split(key)
    _run(state, warmup_key).pos.block_until_ready()

    # Use a fixed key for benchmark consistency. Derive it once, outside the
    # timed function, so host-side key splitting is not measured.
    _, run_key = jax.random.split(jax.random.key(1))

    def benchmark_fn():
        _run(state, run_key).pos.block_until_ready()

    benchmark(benchmark_fn)
|
208
|
+
|
209
|
+
|
210
|
+
def test_benchmark_small_env_world(benchmark):
    """Benchmark a 100-step jitted rollout in a small world-observation env.

    The ``ForagaxWorldEnv`` uses ``size=1_000`` and a single biome with
    object frequencies ``(0.05, 0.05)`` for ``(WALL, FLOWER)``. No aperture
    is passed. The agent always takes ``Actions.DOWN``. Reset and JIT
    compilation happen during setup, so only the rollout is timed.
    """
    env = ForagaxWorldEnv(
        size=1_000,
        objects=(WALL, FLOWER),
        biomes=(Biome(object_frequencies=(0.05, 0.05)),),
    )
    params = env.default_params
    key = jax.random.key(0)
    key, reset_key = jax.random.split(key)
    _, state = env.reset(reset_key, params)

    @jax.jit
    def _run(state, key):
        def f(carry, _):
            state, key = carry
            key, step_key = jax.random.split(key, 2)
            _, new_state, _, _, _ = env.step(step_key, state, Actions.DOWN, params)
            return (new_state, key), None

        (final_state, _), _ = jax.lax.scan(f, (state, key), None, length=100)
        return final_state

    # warm-up compilation
    _, warmup_key = jax.random.split(key)
    _run(state, warmup_key).pos.block_until_ready()

    # Use a fixed key for benchmark consistency. Derive it once, outside the
    # timed function, so host-side key splitting is not measured.
    _, run_key = jax.random.split(jax.random.key(1))

    def benchmark_fn():
        _run(state, run_key).pos.block_until_ready()

    benchmark(benchmark_fn)
|