rapidfireai 0.0.1__py3-none-any.whl → 0.9.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic; see the registry's advisory page for this release for more details.
- rapidfireai/__init__.py +11 -5
- rapidfireai/automl/__init__.py +20 -0
- rapidfireai/automl/base.py +48 -0
- rapidfireai/automl/datatypes.py +42 -0
- rapidfireai/automl/grid_search.py +125 -0
- rapidfireai/automl/model_config.py +102 -0
- rapidfireai/automl/random_search.py +145 -0
- rapidfireai/backend/__init__.py +0 -0
- rapidfireai/backend/chunks.py +63 -0
- rapidfireai/backend/controller.py +637 -0
- rapidfireai/backend/scheduler.py +137 -0
- rapidfireai/backend/worker.py +272 -0
- rapidfireai/cli.py +380 -0
- rapidfireai/db/__init__.py +0 -0
- rapidfireai/db/db_interface.py +135 -0
- rapidfireai/db/rf_db.py +694 -0
- rapidfireai/db/tables.sql +64 -0
- rapidfireai/dispatcher/dispatcher.py +391 -0
- rapidfireai/dispatcher/gunicorn.conf.py +25 -0
- rapidfireai/experiment.py +168 -0
- rapidfireai/frontend/build/asset-manifest.json +276 -0
- rapidfireai/frontend/build/favicon.ico +0 -0
- rapidfireai/frontend/build/index.html +1 -0
- rapidfireai/frontend/build/manifest.json +15 -0
- rapidfireai/frontend/build/pdf.worker.js +1 -0
- rapidfireai/frontend/build/report.html +39 -0
- rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
- rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
- rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
- rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
- rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
- rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
- rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
- rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
- rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
- rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
- rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
- rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
- rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
- rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
- rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
- rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
- rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
- rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
- rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
- rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
- rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
- rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
- rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
- rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
- rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
- rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
- rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
- rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
- rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
- rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
- rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
- rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
- rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
- rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
- rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
- rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
- rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
- rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
- rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
- rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
- rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
- rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
- rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
- rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
- rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
- rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
- rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
- rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
- rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
- rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
- rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
- rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
- rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
- rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
- rapidfireai/frontend/proxy_middleware.py +233 -0
- rapidfireai/frontend/server.py +25 -0
- rapidfireai/ml/__init__.py +0 -0
- rapidfireai/ml/callbacks.py +176 -0
- rapidfireai/ml/checkpoint_utils.py +540 -0
- rapidfireai/ml/trainer.py +309 -0
- rapidfireai/start.sh +634 -0
- rapidfireai/utils/__init__.py +0 -0
- rapidfireai/utils/automl_utils.py +51 -0
- rapidfireai/utils/constants.py +141 -0
- rapidfireai/utils/datapaths.py +69 -0
- rapidfireai/utils/exceptions.py +82 -0
- rapidfireai/utils/experiment_utils.py +370 -0
- rapidfireai/utils/logging.py +87 -0
- rapidfireai/utils/mlflow_manager.py +121 -0
- rapidfireai/utils/serialize.py +15 -0
- rapidfireai/utils/shm_manager.py +469 -0
- rapidfireai/utils/trainer_config.py +23 -0
- rapidfireai/utils/worker_manager.py +219 -0
- rapidfireai/version.py +6 -0
- rapidfireai-0.9.10.dist-info/METADATA +247 -0
- rapidfireai-0.9.10.dist-info/RECORD +318 -0
- rapidfireai-0.9.10.dist-info/entry_points.txt +2 -0
- rapidfireai-0.0.1.dist-info/METADATA +0 -37
- rapidfireai-0.0.1.dist-info/RECORD +0 -6
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/WHEEL +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/top_level.txt +0 -0
rapidfireai/__init__.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
|
1
1
|
"""
|
|
2
|
-
RapidFire AI
|
|
2
|
+
RapidFire AI
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
__version__
|
|
6
|
-
|
|
7
|
-
|
|
5
|
+
from .version import __version__, __version_info__
|
|
6
|
+
|
|
7
|
+
__author__ = "RapidFire AI Inc."
|
|
8
|
+
__email__ = "support@rapidfire.ai"
|
|
9
|
+
|
|
10
|
+
from rapidfireai.experiment import Experiment
|
|
11
|
+
|
|
8
12
|
|
|
9
13
|
def coming_soon():
    """Return a short notice that the full package functionality is still in progress."""
    message = "RapidFire AI package is under development. Stay tuned!"
    return message
|
|
16
|
+
|
|
17
|
+
__all__ = ["Experiment"]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""AutoML module for hyperparameter optimization."""
|
|
2
|
+
|
|
3
|
+
from .base import AutoMLAlgorithm
|
|
4
|
+
from .datatypes import List, Range
|
|
5
|
+
from .grid_search import RFGridSearch
|
|
6
|
+
from .model_config import RFDPOConfig, RFGRPOConfig, RFLoraConfig, RFModelConfig, RFSFTConfig
|
|
7
|
+
from .random_search import RFRandomSearch
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"List",
|
|
11
|
+
"Range",
|
|
12
|
+
"RFGridSearch",
|
|
13
|
+
"RFRandomSearch",
|
|
14
|
+
"AutoMLAlgorithm",
|
|
15
|
+
"RFModelConfig",
|
|
16
|
+
"RFLoraConfig",
|
|
17
|
+
"RFSFTConfig",
|
|
18
|
+
"RFDPOConfig",
|
|
19
|
+
"RFGRPOConfig",
|
|
20
|
+
]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Base classes and configurations for AutoML algorithms."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from rapidfireai.automl.datatypes import List
|
|
7
|
+
from rapidfireai.automl.model_config import RFModelConfig
|
|
8
|
+
from rapidfireai.utils.exceptions import AutoMLException
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AutoMLAlgorithm(ABC):
    """Base class for AutoML algorithms.

    Stores the normalized model configurations and the trainer type; subclasses implement
    :meth:`get_runs` to turn them into concrete per-run hyperparameter dictionaries.
    """

    VALID_TRAINER_TYPES = {"SFT", "DPO", "GRPO"}

    def __init__(self, configs=None, create_model_fn=None, trainer_type: str = "SFT", num_runs: int = 1):
        """Initialize AutoML algorithm with configurations and trainer type.

        Args:
            configs: A single ``RFModelConfig``, a ``list``/``tuple`` of them, a ``List``
                wrapper of them, or None for no configs.
            create_model_fn: Accepted for interface compatibility.
                NOTE(review): this base class neither stores nor calls it.
            trainer_type: One of ``VALID_TRAINER_TYPES``; matched case-insensitively.
            num_runs: Stored as ``self.num_runs`` for subclasses.

        Raises:
            AutoMLException: If ``trainer_type`` is invalid or any config is not an
                ``RFModelConfig`` (the underlying error is chained as ``__cause__``).
        """
        try:
            self.configs = self._normalize_configs(configs)
            if not isinstance(trainer_type, str):
                raise AutoMLException(f"trainer_type must be a string, got {type(trainer_type)}")
            self.trainer_type = trainer_type.upper()
            self.num_runs = num_runs

            if self.trainer_type not in self.VALID_TRAINER_TYPES:
                # Sorted so the message is stable; set iteration order is hash-randomized.
                raise AutoMLException(f"trainer_type must be one of {sorted(self.VALID_TRAINER_TYPES)}")

            self._validate_configs()
        except Exception as e:
            raise AutoMLException(f"Error initializing {self.__class__.__name__}: {e}") from e

    def _normalize_configs(self, configs):
        """Normalize ``configs`` to list format.

        ``None`` means "no configs". A ``List`` wrapper yields its ``values``, a ``list`` is
        returned as-is, and a ``tuple`` is converted to a list. Anything else is wrapped in a
        one-element list, so an invalid value (even a falsy one) is reported by
        :meth:`_validate_configs` instead of being silently dropped.
        """
        if configs is None:
            return []
        if isinstance(configs, List):
            return configs.values
        if isinstance(configs, list):
            return configs
        if isinstance(configs, tuple):
            return list(configs)
        return [configs]

    def _validate_configs(self):
        """Validate all configs are RFModelConfig instances.

        Raises:
            AutoMLException: On the first config that is not an ``RFModelConfig``.
        """
        for config in self.configs:
            if not isinstance(config, RFModelConfig):
                raise AutoMLException(f"All configs must be RFModelConfig instances, got {type(config)}")

    @abstractmethod
    def get_runs(self, seed: int) -> list[dict[str, Any]]:
        """Generate hyperparameter combinations for different training configurations.

        Args:
            seed: Non-negative integer seed.

        Raises:
            AutoMLException: If ``seed`` is not a non-negative ``int``.
        """
        if not isinstance(seed, int) or seed < 0:
            raise AutoMLException("seed must be a non-negative integer")
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Contains classes for representing hyperparameter data types."""
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
|
|
5
|
+
# TODO: need to set seed for random module.
|
|
6
|
+
# TODO: List.sample() will not work for nested lists.
|
|
7
|
+
# TODO: add support for sampling methods like 'uniform' and 'loguniform'.
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Range:
|
|
11
|
+
"""Represents a range of values for a hyperparameter."""
|
|
12
|
+
|
|
13
|
+
def __init__(self, start, end, dtype: str | None = None):
|
|
14
|
+
if dtype is None:
|
|
15
|
+
self.dtype = "int" if isinstance(start, int) and isinstance(end, int) else "float"
|
|
16
|
+
else:
|
|
17
|
+
if dtype not in ("int", "float"):
|
|
18
|
+
raise ValueError("dtype must be either 'int' or 'float'.")
|
|
19
|
+
self.dtype = dtype
|
|
20
|
+
if not (isinstance(start, (int, float)) and isinstance(end, (int, float))):
|
|
21
|
+
raise ValueError("start and end must be either int or float.")
|
|
22
|
+
self.start = start
|
|
23
|
+
self.end = end
|
|
24
|
+
|
|
25
|
+
def sample(self):
|
|
26
|
+
"""Sample a value from the range [self.start, self.end]."""
|
|
27
|
+
if self.dtype == "int":
|
|
28
|
+
return random.randint(self.start, self.end)
|
|
29
|
+
return random.uniform(self.start, self.end)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class List:
|
|
33
|
+
"""Represents a list of values for a hyperparameter."""
|
|
34
|
+
|
|
35
|
+
def __init__(self, values):
|
|
36
|
+
if not isinstance(values, list):
|
|
37
|
+
raise ValueError("List expects a list of values.")
|
|
38
|
+
self.values = values
|
|
39
|
+
|
|
40
|
+
def sample(self):
|
|
41
|
+
"""Sample a value from the list."""
|
|
42
|
+
return random.choice(self.values)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""Grid search implementation for AutoML training configurations."""
|
|
2
|
+
|
|
3
|
+
from itertools import product
|
|
4
|
+
from typing import Any, Dict
|
|
5
|
+
from typing import List as ListType
|
|
6
|
+
|
|
7
|
+
from rapidfireai.automl.base import AutoMLAlgorithm
|
|
8
|
+
from rapidfireai.automl.datatypes import List
|
|
9
|
+
from rapidfireai.utils.exceptions import AutoMLException
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def recursive_expand_gridsearch(item: Any):
    """Lazily yield every concrete variant of ``item`` with ``List`` choices expanded.

    * ``dict``: yields one new dict per element of the Cartesian product of each value's
      expansion, with keys kept in their original order.
    * ``List``: yields the expansions of each contained value, one after another.
    * anything else (including plain Python lists): yielded unchanged.
    """
    if isinstance(item, dict):
        names = list(item)
        per_key_options = [list(recursive_expand_gridsearch(item[name])) for name in names]
        for combo in product(*per_key_options):
            yield {name: value for name, value in zip(names, combo)}
        return

    if isinstance(item, List):
        for candidate in item.values:
            yield from recursive_expand_gridsearch(candidate)
        return

    yield item
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RFGridSearch(AutoMLAlgorithm):
    """Grid search algorithm that generates all hyperparameter combinations."""

    # RFModelConfig attributes that get_runs() handles explicitly. Every other non-None
    # attribute found in a config's __dict__ is forwarded to the Trainer via
    # ``additional_kwargs``.
    # FIXME: this is a hack to get the additional kwargs, we should find a better way to do this
    _EXCLUDED_ATTRS = frozenset(
        {
            "model_name",
            "tokenizer",
            "tokenizer_kwargs",
            "model_type",
            "model_kwargs",
            "peft_config",
            "training_args",
            "ref_model_name",
            "ref_model_type",
            "ref_model_kwargs",
            "reward_funcs",
        }
    )

    @staticmethod
    def _expand(params: Any) -> ListType[Any]:
        """Return every grid point of ``params``, or ``[{}]`` when ``params`` is None."""
        return [{}] if params is None else list(recursive_expand_gridsearch(params))

    @staticmethod
    def _peft_candidates(peft_config: Any) -> ListType[Any]:
        """Normalize a config's ``peft_config`` (None, ``List``, list, or single) to a list."""
        if peft_config is None:
            return [None]
        if isinstance(peft_config, List):
            return peft_config.values
        if isinstance(peft_config, list):
            return peft_config
        return [peft_config]

    def get_runs(self, seed: int) -> ListType[Dict[str, Any]]:
        """Generate all possible hyperparameter combinations for grid search.

        Args:
            seed: Non-negative integer. Grid search is exhaustive, so the seed is only
                validated here and does not influence the result.

        Returns:
            One dict per run. Each run dict (and, for DPO, its ``ref_model_config``) is a
            distinct object.

        Raises:
            AutoMLException: If ``seed`` is invalid or run generation fails (the original
                error is chained as ``__cause__``).
        """
        if not isinstance(seed, int) or seed < 0:
            raise AutoMLException("seed must be a non-negative integer")

        try:
            runs: ListType[Dict[str, Any]] = []
            for config in self.configs:
                for peft_config in self._peft_candidates(config.peft_config):
                    runs.extend(self._runs_for(config, peft_config))
            return runs
        except Exception as e:
            raise AutoMLException(f"Error generating runs: {e}") from e

    def _runs_for(self, config: Any, peft_config: Any) -> ListType[Dict[str, Any]]:
        """Build every run for one ``config`` combined with one candidate ``peft_config``."""
        peft_instances = self._expand(None if peft_config is None else peft_config._user_params)
        training_instances = self._expand(
            None if config.training_args is None else config.training_args._user_params
        )
        model_kwargs_instances = self._expand(config.model_kwargs)
        ref_model_kwargs_instances = self._expand(config.ref_model_kwargs)
        reward_funcs_instances = self._expand(config.reward_funcs)

        # Get additional kwargs for Trainer: everything not handled explicitly above.
        extra_kwargs = {
            k: v for k, v in config.__dict__.items() if k not in self._EXCLUDED_ATTRS and v is not None
        }
        # An empty dict is treated like "no extra kwargs": a single empty grid point.
        extra_kwargs_instances = self._expand(extra_kwargs or None)

        runs: ListType[Dict[str, Any]] = []
        # product() advances its last argument fastest, which is the same ordering as nested
        # loops over peft -> training -> model_kwargs -> additional kwargs.
        for peft_params, training_params, model_kwargs, additional_kwargs in product(
            peft_instances, training_instances, model_kwargs_instances, extra_kwargs_instances
        ):
            leaf = {
                "trainer_type": self.trainer_type,
                "training_args": training_params,
                "peft_params": peft_params,
                "model_name": config.model_name,
                "tokenizer": config.tokenizer,
                "tokenizer_kwargs": config.tokenizer_kwargs,
                "model_type": config.model_type,
                "model_kwargs": model_kwargs,
                "additional_kwargs": additional_kwargs,
            }

            if self.trainer_type == "DPO":
                # Build a fresh run dict (and a fresh ref_model_config) per variant. Mutating
                # and re-appending one shared dict would make every run alias the last variant.
                for ref_model_kwargs in ref_model_kwargs_instances:
                    runs.append(
                        {
                            **leaf,
                            "ref_model_config": {
                                "model_name": config.ref_model_name,
                                "model_type": config.ref_model_type,
                                "model_kwargs": ref_model_kwargs,
                            },
                        }
                    )
            elif self.trainer_type == "GRPO":
                # Same aliasing concern as DPO: one new dict per reward-function variant.
                for reward_func in reward_funcs_instances:
                    runs.append({**leaf, "reward_funcs": reward_func})
            else:
                runs.append(leaf)

        return runs
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Model configuration for AutoML training."""
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import copy
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Callable, Optional, Type, Union, get_type_hints
|
|
7
|
+
|
|
8
|
+
from peft import LoraConfig
|
|
9
|
+
from trl import DPOConfig, GRPOConfig, SFTConfig
|
|
10
|
+
|
|
11
|
+
from rapidfireai.automl.datatypes import List, Range
|
|
12
|
+
|
|
13
|
+
def _create_rf_class(base_class: Type, class_name: str):
    """Create an RF wrapper class around ``base_class``.

    The generated class subclasses ``base_class`` and accepts the same keyword
    arguments as its constructor. Any of them may also be a ``List`` or
    ``Range`` search-space value.

    Every user-supplied kwarg (search spaces included) is deep-copied into
    ``_user_params``. Only the non-search-space values are forwarded to
    ``base_class.__init__``, so the base constructor never sees a ``List`` or
    ``Range``.

    Args:
        base_class: The class to wrap (e.g. a peft/trl config class).
        class_name: ``__name__`` of the generated class.

    Returns:
        The newly created wrapper class.

    Raises:
        ValueError: If ``base_class`` is not a class.
    """
    if not inspect.isclass(base_class):
        raise ValueError(f"base_class must be a class, got {type(base_class)}")

    sig = inspect.signature(base_class.__init__)
    constructor_params = [p for p in sig.parameters.keys() if p != "self"]

    # Widen the annotation of every constructor parameter so that List/Range
    # values are also declared as acceptable.
    type_hints = get_type_hints(base_class)
    new_type_hints = {}
    for param_name, param_type in type_hints.items():
        if param_name in constructor_params:
            new_type_hints[param_name] = param_type | List | Range

    def __init__(self, **kwargs):
        # Independent copy of exactly what the user passed, search spaces included.
        self._user_params = copy.deepcopy(kwargs)
        self._constructor_params = constructor_params
        # While True, __setattr__ does not record assignments made by the base constructor.
        self._initializing = True

        # Search-space values cannot be given to the base constructor; it is
        # initialised only with the concrete values.
        parent_kwargs = {
            key: value
            for key, value in kwargs.items()
            if not isinstance(value, (List, Range))
        }
        base_class.__init__(self, **parent_kwargs)

        self._initializing = False

    def copy_config(self):
        """Return a new instance rebuilt from a deep copy of ``_user_params``."""
        copied_params = copy.deepcopy(self._user_params)
        return self.__class__(**copied_params)

    def __setattr__(self, name, value):
        """Set an attribute, mirroring constructor-parameter updates into ``_user_params``.

        Any constructor parameter assigned after ``__init__`` has finished is
        recorded. This includes parameters the user did not pass originally,
        so that ``copy()`` and anything else reading ``_user_params`` sees the
        current value.
        """
        if (
            name in getattr(self, "_constructor_params", ())
            and hasattr(self, "_user_params")
            and not getattr(self, "_initializing", True)  # don't record during __init__
        ):
            self._user_params[name] = value

        base_class.__setattr__(self, name, value)

    return type(
        class_name,
        (base_class,),
        {
            "__doc__": f"RF version of {base_class.__name__}",
            "__annotations__": new_type_hints,
            "__init__": __init__,
            "copy": copy_config,
            "__setattr__": __setattr__,
        },
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# RF wrapper classes for the external peft/trl config types. Each accepts its
# base class's constructor kwargs, and any of them may be a List/Range value.
RFLoraConfig = _create_rf_class(base_class=LoraConfig, class_name="RFLoraConfig")
RFSFTConfig = _create_rf_class(base_class=SFTConfig, class_name="RFSFTConfig")
RFDPOConfig = _create_rf_class(base_class=DPOConfig, class_name="RFDPOConfig")
RFGRPOConfig = _create_rf_class(base_class=GRPOConfig, class_name="RFGRPOConfig")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class RFModelConfig:
    """Model configuration for AutoML training.

    Several fields accept a search-space value (``List``/``Range``, or RF
    config wrappers whose own parameters may be search spaces) in addition to
    a concrete value. The AutoML search algorithms expand these into concrete
    per-run settings.
    """

    model_name: Optional[str] = None
    tokenizer: Optional[str] = None
    tokenizer_kwargs: Optional[dict[str, Any]] = None
    formatting_func: Optional[Union[Callable, List]] = None
    compute_metrics: Optional[Union[Callable, List]] = None
    peft_config: Optional[Union[RFLoraConfig, List]] = None
    training_args: Optional[Union[RFSFTConfig, RFDPOConfig, RFGRPOConfig]] = None
    model_type: Optional[str] = "causal_lm"
    model_kwargs: Optional[dict[str, Any]] = None
    # Reference-model settings; the search algorithms place these in the run
    # only for DPO trainers.
    ref_model_name: Optional[str] = None
    ref_model_type: Optional[str] = None
    ref_model_kwargs: Optional[dict[str, Any]] = None
    # Reward function(s); placed in the run only for GRPO trainers.
    reward_funcs: Optional[Union[str, List, Callable, Any]] = None
    generation_config: Optional[dict[str, Any]] = None

    def copy(self):
        """Return a deep copy of this RFModelConfig, including nested config objects.

        NOTE(review): this deep-copies instance state directly, unlike the RF
        config wrappers' own ``copy()``, which rebuilds from ``_user_params``.
        Unifying the two is an open FIXME.
        """
        return copy.deepcopy(self)
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Random search implementation for AutoML hyperparameter optimization."""
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
import json
|
|
5
|
+
from itertools import product
|
|
6
|
+
from typing import Any, Dict
|
|
7
|
+
from typing import List as ListType
|
|
8
|
+
|
|
9
|
+
from rapidfireai.automl.base import AutoMLAlgorithm
|
|
10
|
+
from rapidfireai.automl.datatypes import List, Range
|
|
11
|
+
from rapidfireai.utils.exceptions import AutoMLException
|
|
12
|
+
from rapidfireai.utils.serialize import encode_payload
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def recursive_expand_randomsearch(item: Any):
    """Return a concrete value drawn from a possibly nested search space.

    - Dicts are walked recursively.
    - Each ``List`` or ``Range`` is replaced by one value obtained from its
      ``sample()`` method. That value is expanded in turn, so a ``List`` whose
      options contain further search spaces (e.g. dicts holding ``Range``
      values) still yields a fully concrete result.
    - Any other value is returned unchanged.

    Args:
        item: A dict, ``List``, ``Range``, or plain value.

    Returns:
        ``item`` with every reachable search space replaced by a sampled value.
    """
    if isinstance(item, dict):
        return {key: recursive_expand_randomsearch(value) for key, value in item.items()}
    if isinstance(item, (List, Range)):
        # The chosen option may itself contain search spaces; expand those too.
        return recursive_expand_randomsearch(item.sample())
    return item
|
|
24
|
+
|
|
25
|
+
class RFRandomSearch(AutoMLAlgorithm):
    """Random search algorithm that samples num_runs hyperparameter combinations."""

    # FIXME: avoid hardcoding the excluded attributes
    # Config attributes that _sample_run handles explicitly. Every other
    # non-None attribute of the config is forwarded as "additional_kwargs".
    _EXCLUDED_ATTRS = frozenset(
        {
            "model_name",
            "tokenizer",
            "tokenizer_kwargs",
            "model_type",
            "model_kwargs",
            "peft_config",
            "training_args",
            "ref_model_name",
            "ref_model_type",
            "ref_model_kwargs",
            "reward_funcs",
        }
    )

    def get_runs(self, seed: int = 42) -> ListType[Dict[str, Any]]:
        """Generate ``num_runs`` unique, randomly sampled run configurations.

        Args:
            seed: Seed for Python's module-level ``random`` generator. Must be a
                non-negative int. ``None`` is accepted and passed to
                ``random.seed`` as-is.

        Returns:
            A list of ``num_runs`` run dicts. Duplicates are detected by
            comparing ``encode_payload`` encodings.

        Raises:
            AutoMLException: If ``seed`` or ``num_runs`` is invalid.
            AutoMLException: If fewer than ``num_runs`` unique configurations
                are found within ``10 * num_runs`` sampling attempts.
            AutoMLException: If sampling fails for any other reason.
        """
        if seed is not None and (not isinstance(seed, int) or seed < 0):
            raise AutoMLException("seed must be a non-negative integer")

        if not isinstance(self.num_runs, int) or self.num_runs <= 0:
            raise AutoMLException("num_runs must be a positive integer")

        # Reseeds the process-wide RNG. Presumably List/Range.sample() draw
        # from it, which makes runs reproducible per seed — TODO confirm.
        random.seed(seed)

        try:
            runs = []
            seen_configs = set()
            max_attempts = self.num_runs * 10
            attempts = 0

            while len(runs) < self.num_runs and attempts < max_attempts:
                attempts += 1

                leaf = self._sample_run(List(self.configs).sample())

                # Deduplicate on a hashable encoding of the whole run config.
                config_hash = encode_payload(leaf)
                if config_hash not in seen_configs:
                    seen_configs.add(config_hash)
                    runs.append(leaf)

            if len(runs) < self.num_runs:
                raise AutoMLException(
                    f"Could not generate {self.num_runs} unique configurations. "
                    f"Generated {len(runs)} unique configs after {attempts} attempts. "
                )

            return runs

        except AutoMLException:
            # Already descriptive; do not wrap it in a second AutoMLException.
            raise
        except Exception as e:
            raise AutoMLException(f"Error generating runs: {e}") from e

    def _sample_run(self, config) -> Dict[str, Any]:
        """Draw one concrete run dict from ``config`` (an RFModelConfig-like object).

        Keep the sampling order below unchanged. It determines which random
        draw each field receives, and therefore which runs a given seed yields.
        """
        peft_params = self._sample_user_params(self._select_peft_config(config.peft_config))
        training_params = self._sample_user_params(config.training_args)
        model_kwargs = self._expand_or_empty(config.model_kwargs)
        # Sampled for every trainer type, but only placed in the run for DPO.
        ref_model_kwargs = self._expand_or_empty(config.ref_model_kwargs)
        # NOTE(review): a missing reward_funcs becomes {} (not None) for GRPO runs.
        reward_funcs = self._expand_or_empty(config.reward_funcs)

        additional_kwargs = {
            k: v
            for k, v in config.__dict__.items()
            if k not in self._EXCLUDED_ATTRS and v is not None
        }
        additional_kwargs_sampled = (
            recursive_expand_randomsearch(additional_kwargs) if additional_kwargs else {}
        )

        leaf = {
            "trainer_type": self.trainer_type,
            "training_args": training_params,
            "peft_params": peft_params,
            "model_name": config.model_name,
            "tokenizer": config.tokenizer,
            "tokenizer_kwargs": config.tokenizer_kwargs,
            "model_type": config.model_type,
            "model_kwargs": model_kwargs,
            "additional_kwargs": additional_kwargs_sampled,
        }

        if self.trainer_type == "DPO":
            leaf["ref_model_config"] = {
                "model_name": config.ref_model_name,
                "model_type": config.ref_model_type,
                "model_kwargs": ref_model_kwargs,
            }
            # FIXME: correct ref args
        elif self.trainer_type == "GRPO":
            leaf["reward_funcs"] = reward_funcs

        return leaf

    @staticmethod
    def _select_peft_config(peft_config):
        """Pick one PEFT config.

        Samples one entry from a Python list or a ``List``; returns ``None``
        or a single config unchanged.
        """
        if peft_config is None:
            return None
        if isinstance(peft_config, list):
            return List(peft_config).sample()
        if isinstance(peft_config, List):
            return peft_config.sample()
        return peft_config

    @staticmethod
    def _sample_user_params(rf_config) -> Dict[str, Any]:
        """Sample concrete values from an RF config's ``_user_params`` ({} when it is None)."""
        if rf_config is None:
            return {}
        return recursive_expand_randomsearch(rf_config._user_params)

    @staticmethod
    def _expand_or_empty(value):
        """Sample concrete values from ``value`` ({} when it is None)."""
        return {} if value is None else recursive_expand_randomsearch(value)
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""This module contains the DatasetChunker class which is responsible for chunking a PyTorch Dataset
|
|
2
|
+
into chunks for distributed processing."""
|
|
3
|
+
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DatasetChunks:
    """Split a HuggingFace Dataset into at most ``n_chunks`` contiguous chunks.

    - Chunk sizes differ by at most one: the first ``dataset_size % n_chunks``
      chunks each hold one extra item.
    - Chunk ``i`` covers the half-open index range ``[start, end)``.
    - Chunks that would be empty are omitted. This happens when the dataset
      has fewer rows than ``n_chunks``. The remaining ids are always
      ``0 .. len(chunk_ids) - 1``.
    """

    def __init__(self, dataset: Dataset, n_chunks: int):
        """Compute the chunk boundaries for ``dataset``.

        Args:
            dataset: The dataset to split.
            n_chunks: Requested number of chunks; must be positive.

        Raises:
            ValueError: If ``n_chunks`` is not positive.
        """
        self.dataset = dataset
        self.n_chunks = n_chunks
        self.dataset_size = len(dataset)

        if n_chunks <= 0:
            raise ValueError(f"n_chunks must be positive, got {n_chunks}")

        # Every chunk gets base_size items; the first extra_items chunks get one more.
        self.base_size, self.extra_items = divmod(self.dataset_size, n_chunks)
        self.chunk_indices = self._create_chunk_indices()

    def _create_chunk_indices(self):
        """Return ``{chunk_id: (start, end)}`` half-open ranges, skipping empty chunks."""
        base_size, extra_items = divmod(self.dataset_size, self.n_chunks)

        chunks = {}
        current_idx = 0
        for chunk_id in range(self.n_chunks):
            chunk_size = base_size + (1 if chunk_id < extra_items else 0)
            if chunk_size > 0:  # only create non-empty chunks
                chunks[chunk_id] = (current_idx, current_idx + chunk_size)
                current_idx += chunk_size

        return chunks

    def get_chunk(self, chunk_id: int) -> Dataset:
        """Return chunk ``chunk_id`` as a subset of the dataset (via ``Dataset.select``).

        Raises:
            ValueError: If ``chunk_id`` is not one of ``chunk_ids``.
        """
        if chunk_id not in self.chunk_indices:
            if not self.chunk_indices:
                raise ValueError(
                    f"Invalid chunk_id {chunk_id}. No chunks available (dataset is empty)"
                )
            raise ValueError(
                f"Invalid chunk_id {chunk_id}. Valid range: 0-{len(self.chunk_indices) - 1}"
            )

        start_idx, end_idx = self.chunk_indices[chunk_id]
        # select() accepts any iterable of indices. Passing the range directly
        # avoids materialising a per-chunk list of indices.
        return self.dataset.select(range(start_idx, end_idx))

    def get_chunk_size(self, chunk_id: int) -> int:
        """Return the number of items in chunk ``chunk_id``.

        Raises:
            ValueError: If ``chunk_id`` is not one of ``chunk_ids``.
        """
        if chunk_id not in self.chunk_indices:
            raise ValueError(f"Invalid chunk_id {chunk_id}")
        start_idx, end_idx = self.chunk_indices[chunk_id]
        return end_idx - start_idx

    @property
    def chunk_ids(self):
        """All available chunk ids, in ascending order."""
        return list(self.chunk_indices.keys())
|