rapidfireai 0.0.1__py3-none-any.whl → 0.9.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic. Click here for more details.
- rapidfireai/__init__.py +11 -5
- rapidfireai/automl/__init__.py +20 -0
- rapidfireai/automl/base.py +48 -0
- rapidfireai/automl/datatypes.py +42 -0
- rapidfireai/automl/grid_search.py +125 -0
- rapidfireai/automl/model_config.py +102 -0
- rapidfireai/automl/random_search.py +145 -0
- rapidfireai/backend/__init__.py +0 -0
- rapidfireai/backend/chunks.py +63 -0
- rapidfireai/backend/controller.py +637 -0
- rapidfireai/backend/scheduler.py +137 -0
- rapidfireai/backend/worker.py +272 -0
- rapidfireai/cli.py +380 -0
- rapidfireai/db/__init__.py +0 -0
- rapidfireai/db/db_interface.py +135 -0
- rapidfireai/db/rf_db.py +694 -0
- rapidfireai/db/tables.sql +64 -0
- rapidfireai/dispatcher/dispatcher.py +391 -0
- rapidfireai/dispatcher/gunicorn.conf.py +25 -0
- rapidfireai/experiment.py +168 -0
- rapidfireai/frontend/build/asset-manifest.json +276 -0
- rapidfireai/frontend/build/favicon.ico +0 -0
- rapidfireai/frontend/build/index.html +1 -0
- rapidfireai/frontend/build/manifest.json +15 -0
- rapidfireai/frontend/build/pdf.worker.js +1 -0
- rapidfireai/frontend/build/report.html +39 -0
- rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
- rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
- rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
- rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
- rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
- rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
- rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
- rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
- rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
- rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
- rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
- rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
- rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
- rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
- rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
- rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
- rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
- rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
- rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
- rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
- rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
- rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
- rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
- rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
- rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
- rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
- rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
- rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
- rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
- rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
- rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
- rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
- rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
- rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
- rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
- rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
- rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
- rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
- rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
- rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
- rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
- rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
- rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
- rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
- rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
- rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
- rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
- rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
- rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
- rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
- rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
- rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
- rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
- rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
- rapidfireai/frontend/proxy_middleware.py +233 -0
- rapidfireai/frontend/server.py +25 -0
- rapidfireai/ml/__init__.py +0 -0
- rapidfireai/ml/callbacks.py +176 -0
- rapidfireai/ml/checkpoint_utils.py +540 -0
- rapidfireai/ml/trainer.py +309 -0
- rapidfireai/start.sh +634 -0
- rapidfireai/utils/__init__.py +0 -0
- rapidfireai/utils/automl_utils.py +51 -0
- rapidfireai/utils/constants.py +141 -0
- rapidfireai/utils/datapaths.py +69 -0
- rapidfireai/utils/exceptions.py +82 -0
- rapidfireai/utils/experiment_utils.py +370 -0
- rapidfireai/utils/logging.py +87 -0
- rapidfireai/utils/mlflow_manager.py +121 -0
- rapidfireai/utils/serialize.py +15 -0
- rapidfireai/utils/shm_manager.py +469 -0
- rapidfireai/utils/trainer_config.py +23 -0
- rapidfireai/utils/worker_manager.py +219 -0
- rapidfireai/version.py +6 -0
- rapidfireai-0.9.9.dist-info/METADATA +242 -0
- rapidfireai-0.9.9.dist-info/RECORD +318 -0
- rapidfireai-0.9.9.dist-info/entry_points.txt +2 -0
- rapidfireai-0.0.1.dist-info/METADATA +0 -37
- rapidfireai-0.0.1.dist-info/RECORD +0 -6
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/WHEEL +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""This module contains the Scheduler class which is responsible for scheduling runs on workers to train on a chunk."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Scheduler:
|
|
5
|
+
"""This class is responsible for scheduling runs on to workers to train on a chunk"""
|
|
6
|
+
|
|
7
|
+
def __init__(self, run_ids: list[int], num_workers: int, num_chunks: int) -> None:
|
|
8
|
+
# run_ids are 1 indexed
|
|
9
|
+
# worker_ids are 0 indexed
|
|
10
|
+
# chunk_ids are 0 indexed
|
|
11
|
+
|
|
12
|
+
self.n_runs: int = len(run_ids)
|
|
13
|
+
self.n_workers: int = num_workers
|
|
14
|
+
self.n_chunks: int = num_chunks
|
|
15
|
+
self.run_ids: list[int] = run_ids
|
|
16
|
+
|
|
17
|
+
# create data structures
|
|
18
|
+
self.worker_running_current_run: dict[int, int] = dict.fromkeys(range(self.n_workers), -1)
|
|
19
|
+
self.run_visited_num_chunks: dict[int, int] = dict.fromkeys(self.run_ids, 0)
|
|
20
|
+
self.run_start_chunk_id: dict[int, int] = dict.fromkeys(self.run_ids, 0)
|
|
21
|
+
|
|
22
|
+
# add runs to scheduler
|
|
23
|
+
for run_id in run_ids:
|
|
24
|
+
self.add_run(run_id, 0)
|
|
25
|
+
|
|
26
|
+
def reset_run(self, run_id: int) -> None:
|
|
27
|
+
"""Reset the scheduler for a specific run (used at epoch boundaries)"""
|
|
28
|
+
if run_id in self.run_ids:
|
|
29
|
+
# Reset progress for this run
|
|
30
|
+
self.run_visited_num_chunks[run_id] = 0
|
|
31
|
+
|
|
32
|
+
# If this run is currently assigned to a worker, free the worker
|
|
33
|
+
for worker_id in range(self.n_workers):
|
|
34
|
+
if self.worker_running_current_run[worker_id] == run_id:
|
|
35
|
+
self.worker_running_current_run[worker_id] = -1
|
|
36
|
+
|
|
37
|
+
def add_run(self, run_id: int, run_visited_num_chunks: int, run_start_chunk_id: int = 0) -> None:
|
|
38
|
+
"""Add a new run to the scheduler."""
|
|
39
|
+
if run_id not in self.run_ids:
|
|
40
|
+
self.run_ids.append(run_id)
|
|
41
|
+
self.n_runs = len(self.run_ids)
|
|
42
|
+
|
|
43
|
+
self.run_visited_num_chunks[run_id] = run_visited_num_chunks
|
|
44
|
+
self.run_start_chunk_id[run_id] = run_start_chunk_id
|
|
45
|
+
|
|
46
|
+
def set_completed_task(self, worker_id: int) -> None:
|
|
47
|
+
"""Set a task as completed."""
|
|
48
|
+
run_id = self.worker_running_current_run[worker_id]
|
|
49
|
+
|
|
50
|
+
if run_id != -1:
|
|
51
|
+
self.worker_running_current_run[worker_id] = -1
|
|
52
|
+
self.run_visited_num_chunks[run_id] += 1
|
|
53
|
+
|
|
54
|
+
def remove_run(self, run_id: int) -> int:
|
|
55
|
+
"""Remove a run from the scheduler and return its progress."""
|
|
56
|
+
if run_id not in self.run_ids:
|
|
57
|
+
return 0
|
|
58
|
+
|
|
59
|
+
# Get the progress before removing
|
|
60
|
+
progress = self.run_visited_num_chunks.get(run_id, 0)
|
|
61
|
+
|
|
62
|
+
# Clean up worker assignment
|
|
63
|
+
for worker_id in range(self.n_workers):
|
|
64
|
+
if self.worker_running_current_run[worker_id] == run_id:
|
|
65
|
+
self.worker_running_current_run[worker_id] = -1
|
|
66
|
+
|
|
67
|
+
# Remove from all data structures
|
|
68
|
+
self.run_visited_num_chunks.pop(run_id, None)
|
|
69
|
+
self.run_start_chunk_id.pop(run_id, None)
|
|
70
|
+
|
|
71
|
+
if run_id in self.run_ids:
|
|
72
|
+
self.run_ids.remove(run_id)
|
|
73
|
+
self.n_runs = len(self.run_ids)
|
|
74
|
+
|
|
75
|
+
return progress
|
|
76
|
+
|
|
77
|
+
def schedule(self) -> dict[str, int | bool | None] | None:
|
|
78
|
+
"""
|
|
79
|
+
Schedule a single task based on constraints and preferences.
|
|
80
|
+
Returns {run_id: <>, worker_id: <>, chunk_id: <>, is_last_chunk: <>} if a schedule is possible.
|
|
81
|
+
Returns {run_id: None, worker_id: None, chunk_id: None, is_last_chunk: None} if all runs have seen all chunks.
|
|
82
|
+
Returns {run_id: -1, worker_id: -1, chunk_id: -1, is_last_chunk: None} if all workers are busy or no runs are available.
|
|
83
|
+
"""
|
|
84
|
+
# First check if all workers are busy (most common condition)
|
|
85
|
+
available_workers = [
|
|
86
|
+
worker_id for worker_id in range(self.n_workers) if self.worker_running_current_run[worker_id] == -1
|
|
87
|
+
]
|
|
88
|
+
if not available_workers:
|
|
89
|
+
return {"run_id": -1, "worker_id": -1, "chunk_id": -1, "is_last_chunk": None}
|
|
90
|
+
|
|
91
|
+
# Next check if all runs have seen all chunks (termination condition)
|
|
92
|
+
if all(self.run_visited_num_chunks[run_id] >= self.n_chunks for run_id in self.run_ids):
|
|
93
|
+
return {"run_id": None, "worker_id": None, "chunk_id": None, "is_last_chunk": None}
|
|
94
|
+
|
|
95
|
+
# Get busy runs and available runs
|
|
96
|
+
busy_runs = set(run_id for run_id in self.worker_running_current_run.values() if run_id != -1)
|
|
97
|
+
available_runs = [
|
|
98
|
+
run_id
|
|
99
|
+
for run_id in self.run_ids
|
|
100
|
+
if self.run_visited_num_chunks[run_id] < self.n_chunks and run_id not in busy_runs
|
|
101
|
+
]
|
|
102
|
+
|
|
103
|
+
# If no available runs, return busy state
|
|
104
|
+
if not available_runs:
|
|
105
|
+
return {"run_id": -1, "worker_id": -1, "chunk_id": -1, "is_last_chunk": None}
|
|
106
|
+
|
|
107
|
+
# Find the run with least progress, then lowest run_id for tie-breaking
|
|
108
|
+
# NOTE: any newly inserted clones will take priority
|
|
109
|
+
# NOTE: prioritize by run_id if run_visited_num_chunks is the same
|
|
110
|
+
|
|
111
|
+
run_id = min(available_runs, key=lambda run_id: (self.run_visited_num_chunks[run_id], run_id))
|
|
112
|
+
worker_id = available_workers[0] # Pick first available worker
|
|
113
|
+
chunk_id = (
|
|
114
|
+
self.run_visited_num_chunks[run_id] + self.run_start_chunk_id[run_id]
|
|
115
|
+
) % self.n_chunks # Next chunk in sequence starting from run_start_chunk_id (chunk_id is 0-indexed)
|
|
116
|
+
is_last_chunk = chunk_id == self.n_chunks - 1
|
|
117
|
+
|
|
118
|
+
# Update internal state immediately
|
|
119
|
+
self.worker_running_current_run[worker_id] = run_id
|
|
120
|
+
|
|
121
|
+
return {"run_id": run_id, "worker_id": worker_id, "chunk_id": chunk_id, "is_last_chunk": is_last_chunk}
|
|
122
|
+
|
|
123
|
+
def get_status(self) -> dict:
|
|
124
|
+
"""Get current scheduler status for debugging."""
|
|
125
|
+
completed_runs = [run_id for run_id in self.run_ids if self.run_visited_num_chunks[run_id] == self.n_chunks]
|
|
126
|
+
|
|
127
|
+
return {
|
|
128
|
+
"active_runs": len([r for r in self.run_ids if self.run_visited_num_chunks[r] < self.n_chunks]),
|
|
129
|
+
"busy_workers": len([w for w in range(self.n_workers) if self.worker_running_current_run[w] != -1]),
|
|
130
|
+
"completed_runs": len(completed_runs),
|
|
131
|
+
"worker_assignments": {
|
|
132
|
+
w: self.worker_running_current_run[w]
|
|
133
|
+
for w in range(self.n_workers)
|
|
134
|
+
if self.worker_running_current_run[w] != -1
|
|
135
|
+
},
|
|
136
|
+
"run_progress": {r: f"{self.run_visited_num_chunks[r]}/{self.n_chunks}" for r in self.run_ids},
|
|
137
|
+
}
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""This module contains the Worker class which is responsible for handling the worker operations."""
|
|
2
|
+
|
|
3
|
+
import gc
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import traceback
|
|
7
|
+
from contextlib import redirect_stderr, redirect_stdout
|
|
8
|
+
from io import StringIO
|
|
9
|
+
from logging import Logger
|
|
10
|
+
from multiprocessing import Process
|
|
11
|
+
from multiprocessing.managers import DictProxy
|
|
12
|
+
from multiprocessing.synchronize import Event as EventType
|
|
13
|
+
from multiprocessing.synchronize import Lock
|
|
14
|
+
from typing import Any, Callable
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from rapidfireai.backend.chunks import DatasetChunks
|
|
19
|
+
from rapidfireai.db.rf_db import RfDb
|
|
20
|
+
from rapidfireai.ml.checkpoint_utils import (
|
|
21
|
+
save_checkpoint_to_disk,
|
|
22
|
+
save_checkpoint_to_shared_memory,
|
|
23
|
+
save_model_to_shared_memory,
|
|
24
|
+
)
|
|
25
|
+
from rapidfireai.ml.trainer import create_trainer_instance
|
|
26
|
+
from rapidfireai.utils.constants import MLFLOW_URL, RunStatus, TaskStatus, WorkerTask
|
|
27
|
+
from rapidfireai.utils.datapaths import DataPath
|
|
28
|
+
from rapidfireai.utils.exceptions import WorkerException
|
|
29
|
+
from rapidfireai.utils.logging import RFLogger, TrainingLogger
|
|
30
|
+
from rapidfireai.utils.mlflow_manager import MLflowManager
|
|
31
|
+
from rapidfireai.utils.serialize import decode_db_payload
|
|
32
|
+
from rapidfireai.utils.shm_manager import SharedMemoryManager
|
|
33
|
+
from rapidfireai.utils.trainer_config import TrainerConfig
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Worker:
|
|
37
|
+
"""Worker class that handles training and validation of runs"""
|
|
38
|
+
|
|
39
|
+
def __init__(self, worker_id: int, model_registry: DictProxy, process_lock: Lock, shutdown_event: EventType):
    """Set up a worker: shared memory, loggers, database, MLflow, data paths, datasets.

    Args:
        worker_id: Identifier of this worker; also used to name the
            shared-memory manager and both loggers.
        model_registry: Registry proxy passed on to the SharedMemoryManager.
        process_lock: Lock passed on to the SharedMemoryManager.
        shutdown_event: Event stored on the instance for later use.

    Raises:
        WorkerException: If load_datasets() fails to load the datasets.
    """
    # Declared here, not assigned in this method.
    self.process: Process

    # Plain references handed in by the caller
    self.worker_id: int = worker_id
    self.model_registry: DictProxy[int, Any] = model_registry
    self.process_lock: Lock = process_lock
    self.shutdown_event: EventType = shutdown_event

    # Shared-memory manager built on the registry and lock received above
    self.shm_manager = SharedMemoryManager(
        name=f"worker-{worker_id}-shm",
        registry=self.model_registry,
        multiprocess_lock=self.process_lock,
    )

    # Loggers: a general one and a training-specific one, same name
    logger_name = f"worker_{worker_id}"
    self.logger: Logger = RFLogger().create_logger(logger_name)
    self.training_logger: Logger = TrainingLogger().create_logger(logger_name)
    self.logger.debug(f"Worker {self.worker_id} initialized with PID {os.getpid()}")

    # Database handle, then the name of the experiment that is running
    self.db: RfDb = RfDb()
    running_experiment = self.db.get_running_experiment()
    self.experiment_name: str = running_experiment["experiment_name"]

    # MLflow manager, pointed at the running experiment
    self.mlflow_manager: MLflowManager = MLflowManager(MLFLOW_URL)
    self.mlflow_manager.get_experiment(self.experiment_name)

    # Data paths must be initialized before load_datasets() reads the dataset path
    experiments_path = self.db.get_experiments_path(self.experiment_name)
    DataPath.initialize(self.experiment_name, experiments_path)

    # Load datasets and split the training set into chunks
    train_dataset, eval_dataset, num_chunks = self.load_datasets()
    self.eval_dataset = eval_dataset
    self.num_chunks = num_chunks
    self.len_train_dataset = len(train_dataset)
    self.train_dataset_chunks = DatasetChunks(train_dataset, self.num_chunks)
|
|
76
|
+
|
|
77
|
+
def load_datasets(self) -> tuple[torch.utils.data.Dataset | None, torch.utils.data.Dataset | None, int]:
    """Read the stored dataset bundle and return its train, eval and chunk-count entries.

    The file at ``DataPath.dataset_path()`` is read as bytes and decoded with
    ``decode_db_payload``; the decoded object is indexed with the keys
    ``"train"``, ``"eval"`` and ``"num_chunks"``.

    Returns:
        A ``(train, eval, num_chunks)`` tuple taken from the decoded payload.

    Raises:
        WorkerException: Wraps any exception raised while reading, decoding,
            logging or indexing the payload (the original is chained as the cause).
    """
    try:
        with open(DataPath.dataset_path(), "rb") as dataset_file:
            raw_payload = dataset_file.read()
        payload = decode_db_payload(raw_payload)
        self.logger.debug("Loaded datasets")
        return payload["train"], payload["eval"], payload["num_chunks"]
    except Exception as exc:
        raise WorkerException(f"Error loading datasets: {exc}") from exc
|
|
86
|
+
|
|
87
|
+
def run_fit(
    self,
    run_id: int,
    chunk_id: int,
    create_model_fn: Callable,
) -> None:
    """Train one run on one chunk of the training data and persist the result.

    The method reads the run from the database, builds a ``TrainerConfig`` for
    the requested chunk, creates a trainer, trains it, records the new number
    of completed steps and saves checkpoints (shared memory and, depending on
    ``save_strategy`` and the chunk position, disk).

    Output written to stdout/stderr while the trainer is created or trained is
    captured and forwarded to ``self.training_logger`` -- also when that step
    raises, so the captured output is not lost on failure. References held by
    the trainer are dropped and garbage collection / CUDA cache release run in
    a ``finally`` block, so a failed run does not leave them behind while the
    worker goes on to its next task.

    Args:
        run_id: Identifier of the run to train; used for all database calls.
        chunk_id: Index of the training-data chunk to train on.
        create_model_fn: Passed through to ``TrainerConfig`` unchanged.

    Raises:
        Exception: Nothing is caught here; errors from the database, trainer
            creation, training or checkpoint helpers propagate to the caller
            after log forwarding and cleanup have run.
    """

    def forward_captured_output(captured_stdout: StringIO, captured_stderr: StringIO) -> None:
        """Send captured text to the training logger: stdout as info, stderr as error."""
        stdout_text = captured_stdout.getvalue()
        if stdout_text:
            self.training_logger.info(stdout_text)
        stderr_text = captured_stderr.getvalue()
        if stderr_text:
            self.training_logger.error(stderr_text)

    self.logger.debug(f"Received run_fit on worker for run {run_id} with chunk {chunk_id}")

    # get run details
    run_details = self.db.get_run(run_id)
    config_leaf = run_details["config_leaf"]
    mlflow_run_id = run_details["mlflow_run_id"]

    # Seeding is currently disabled (kept from the original for reference):
    # torch.manual_seed(run_details["seed"])
    # np.random.seed(run_details["seed"])
    # random.seed(run_details["seed"])

    # fetch train dataset chunk
    train_dataset_chunk = self.train_dataset_chunks.get_chunk(chunk_id)

    # create trainer config
    trainer_config = TrainerConfig(
        worker_id=self.worker_id,
        run_id=run_id,
        mlflow_run_id=mlflow_run_id,
        config_leaf=config_leaf,
        total_steps=run_details["total_steps"],
        completed_steps=run_details["completed_steps"],
        create_model_fn=create_model_fn,
        train_dataset=train_dataset_chunk,
        eval_dataset=self.eval_dataset,
        warm_started_from=run_details["warm_started_from"],
        num_epochs_completed=run_details["num_epochs_completed"],
    )
    completed_steps = self.db.get_completed_steps(run_id)

    # Hard-coded switch; the "not use_shared_memory" branches below are kept for when it is flipped.
    use_shared_memory = True
    parent_run_details = None
    if trainer_config.warm_started_from is not None:
        parent_run_details = self.db.get_run(trainer_config.warm_started_from)
        if trainer_config.config_leaf.get("trainer_type") == "GRPO":
            # A warm-started GRPO run takes its reward functions from the parent run's config.
            config_leaf["reward_funcs"] = parent_run_details["config_leaf"].get("reward_funcs")
            self.db.set_run_details(run_id, config_leaf=config_leaf)

    trainer_instance = None
    try:
        stdout_buffer = StringIO()
        stderr_buffer = StringIO()
        try:
            with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
                # The second returned value (base model name) is not used here.
                trainer_instance, _ = create_trainer_instance(
                    trainer_config, self.shm_manager, use_shared_memory, self.mlflow_manager, chunk_id
                )

                # if first time, save checkpoint to disk
                if completed_steps == 0 and not use_shared_memory:
                    save_checkpoint_to_disk(trainer_instance, trainer_config, first=True)
        finally:
            # write logs to user logger, even if trainer creation failed
            forward_captured_output(stdout_buffer, stderr_buffer)

        self.logger.debug(f"Beginning training for run {run_id} on chunk {chunk_id}")

        # update base model name in db for run
        trainer_config.config_leaf["model_name"] = trainer_instance.model.config._name_or_path
        self.db.set_run_details(run_id, config_leaf=trainer_config.config_leaf)

        # Train the model
        stdout_buffer = StringIO()
        stderr_buffer = StringIO()
        try:
            with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
                trainer_instance.train()
        finally:
            # write logs to user logger, even if training failed
            forward_captured_output(stdout_buffer, stderr_buffer)

        # update completed steps
        new_completed_steps = completed_steps + trainer_instance.state.global_step
        self.db.set_completed_steps(run_id, new_completed_steps)

        save_strategy = trainer_config.config_leaf.get("training_args", {}).get("save_strategy", "epoch")

        if use_shared_memory:
            # Save checkpoints to shared memory
            save_checkpoint_to_shared_memory(trainer_instance, trainer_config, self.shm_manager)
            if not trainer_config.config_leaf.get("peft_params"):
                # Without PEFT params the full model is stored in shared memory as well.
                save_model_to_shared_memory(
                    trainer_instance.model,
                    trainer_instance.tokenizer,
                    trainer_config,
                    self.shm_manager,
                    "full_model",
                    trainer_config.run_id,
                )
            self.logger.debug(f"Saved checkpoint to shared memory for run {run_id} on chunk {chunk_id}")
            # Disk copy after every chunk ("chunk") or only after the last chunk ("epoch").
            if save_strategy == "chunk" or (save_strategy == "epoch" and chunk_id == self.num_chunks - 1):
                save_checkpoint_to_disk(trainer_instance, trainer_config, completed_steps=new_completed_steps)
                self.logger.debug(f"Saved checkpoint to disk for run {run_id} on chunk {chunk_id}")
        else:  # save checkpoint to disk when not using shared memory
            save_checkpoint_to_disk(trainer_instance, trainer_config, completed_steps=new_completed_steps)
            self.logger.debug(f"Saved checkpoint to disk for run {run_id} on chunk {chunk_id}")

        # Final checkpoint once the last chunk is done and all steps are completed.
        if chunk_id == self.num_chunks - 1 and new_completed_steps >= trainer_config.total_steps:
            save_checkpoint_to_disk(trainer_instance, trainer_config, last=True)
            self.logger.debug(f"Saved final checkpoint for run {run_id} on chunk {chunk_id}")
    finally:
        # clean up all references to shared memory objects (also on failure)
        if trainer_instance is not None:
            for attr_name in ("model", "ref_model", "optimizer", "lr_scheduler"):
                if hasattr(trainer_instance, attr_name):
                    delattr(trainer_instance, attr_name)
            trainer_instance = None

        # run garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    self.logger.debug(f"Completed training for run {run_id} on chunk {chunk_id}")
|
|
214
|
+
|
|
215
|
+
def serve_forever(self) -> None:
    """Poll the database for this worker's scheduled task and run it until shutdown.

    Per the original docstring, this runs in the worker process. Each
    iteration asks the database for the task scheduled for ``self.worker_id``.
    If there is none, or it has the same ``task_id`` as the previously handled
    one, the loop sleeps one second and polls again. A ``TRAIN_VAL`` task is
    marked in progress, executed through ``run_fit`` and then marked completed;
    if ``run_fit`` (or the status update) raises, the run is marked failed with
    the error text and traceback, the task is marked failed, and polling
    continues.

    Any other error (including an unknown task type) is logged, stored as the
    experiment error, and ends the loop. ``shutdown()`` is always called on
    the way out, even if reporting that error fails.
    """
    prev_task_id: int | None = None
    try:
        while not (self.shutdown_event and self.shutdown_event.is_set()):
            try:
                scheduled_task = self.db.get_worker_scheduled_task(self.worker_id)
                if not scheduled_task or scheduled_task["task_id"] == prev_task_id:
                    # no new tasks or same task as previous iteration
                    time.sleep(1)
                    continue

                # get task details
                prev_task_id = scheduled_task["task_id"]
                task_type = scheduled_task["task_type"]
                run_id = scheduled_task["run_id"]
                chunk_id = scheduled_task["chunk_id"]
                create_model_fn = scheduled_task["config_options"]["create_model_fn"]
                self.logger.debug(f"Received task {task_type} for run {run_id}")

                if task_type != WorkerTask.TRAIN_VAL:
                    # Handled by the outer except: reported as an experiment error, loop ends.
                    raise WorkerException(f"Invalid task type: {task_type}")

                self.db.set_worker_task_status(self.worker_id, TaskStatus.IN_PROGRESS)

                # run train and validation function
                try:
                    self.run_fit(run_id, chunk_id, create_model_fn)
                    self.db.set_worker_task_status(self.worker_id, TaskStatus.COMPLETED)
                except Exception as e:
                    self.logger.opt(exception=True).error(
                        f"Error while running run_fit for run {run_id} and chunk {chunk_id}: {e}"
                    )
                    # Newline keeps the message apart from the traceback, as in the handler below.
                    self.db.set_run_details(
                        run_id, status=RunStatus.FAILED, error=str(e) + "\n" + traceback.format_exc()
                    )
                    self.db.set_worker_task_status(self.worker_id, TaskStatus.FAILED)
            except Exception as e:
                self.logger.opt(exception=True).error(f"Worker {self.worker_id} error: {e}")
                self.db.set_experiment_error(str(e) + "\n" + traceback.format_exc())
                break
    finally:
        # Runs on normal exit, after `break`, and when the error handler itself raises.
        self.shutdown()
|
|
256
|
+
|
|
257
|
+
def shutdown(self):
    """Signal shutdown for this worker and close its database connection.

    Per the original docstring this is called by WorkerManager; ``serve_forever``
    also calls it when its loop ends. The shutdown event is set when one is
    present, and ``self.db`` is closed when the attribute exists. A failure
    while closing is only logged at debug level and never raised.
    """
    self.logger.debug(f"Worker {self.worker_id} shutdown requested")

    event = self.shutdown_event
    if event:
        event.set()

    # Close database connection to prevent resource leaks
    try:
        if not hasattr(self, "db"):
            return
        self.db.close()
    except Exception as exc:
        self.logger.debug(f"Error closing database connection: {exc}")
|
|
269
|
+
|
|
270
|
+
def is_alive(self) -> bool:
    """Report whether this worker has a process attached and that process is running.

    ``__init__`` only declares ``self.process`` (an annotation without a value),
    so the attribute may not exist yet; that case is reported as not alive
    instead of raising ``AttributeError``.

    Returns:
        ``True`` only when ``self.process`` exists, is truthy, and its
        ``is_alive()`` result is truthy; ``False`` otherwise.
    """
    process = getattr(self, "process", None)
    return bool(process and process.is_alive())
|