rapidfireai 0.0.1__py3-none-any.whl → 0.9.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic. Click here for more details.
- rapidfireai/__init__.py +11 -5
- rapidfireai/automl/__init__.py +20 -0
- rapidfireai/automl/base.py +48 -0
- rapidfireai/automl/datatypes.py +42 -0
- rapidfireai/automl/grid_search.py +125 -0
- rapidfireai/automl/model_config.py +102 -0
- rapidfireai/automl/random_search.py +145 -0
- rapidfireai/backend/__init__.py +0 -0
- rapidfireai/backend/chunks.py +63 -0
- rapidfireai/backend/controller.py +637 -0
- rapidfireai/backend/scheduler.py +137 -0
- rapidfireai/backend/worker.py +272 -0
- rapidfireai/cli.py +380 -0
- rapidfireai/db/__init__.py +0 -0
- rapidfireai/db/db_interface.py +135 -0
- rapidfireai/db/rf_db.py +694 -0
- rapidfireai/db/tables.sql +64 -0
- rapidfireai/dispatcher/dispatcher.py +391 -0
- rapidfireai/dispatcher/gunicorn.conf.py +25 -0
- rapidfireai/experiment.py +168 -0
- rapidfireai/frontend/build/asset-manifest.json +276 -0
- rapidfireai/frontend/build/favicon.ico +0 -0
- rapidfireai/frontend/build/index.html +1 -0
- rapidfireai/frontend/build/manifest.json +15 -0
- rapidfireai/frontend/build/pdf.worker.js +1 -0
- rapidfireai/frontend/build/report.html +39 -0
- rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
- rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
- rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
- rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
- rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
- rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
- rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
- rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
- rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
- rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
- rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
- rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
- rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
- rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
- rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
- rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
- rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
- rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
- rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
- rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
- rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
- rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
- rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
- rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
- rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
- rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
- rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
- rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
- rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
- rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
- rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
- rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
- rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
- rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
- rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
- rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
- rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
- rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
- rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
- rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
- rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
- rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
- rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
- rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
- rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
- rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
- rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
- rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
- rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
- rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
- rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
- rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
- rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
- rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
- rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
- rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
- rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
- rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
- rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
- rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
- rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
- rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
- rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
- rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
- rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
- rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
- rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
- rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
- rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
- rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
- rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
- rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
- rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
- rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
- rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
- rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
- rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
- rapidfireai/frontend/proxy_middleware.py +233 -0
- rapidfireai/frontend/server.py +25 -0
- rapidfireai/ml/__init__.py +0 -0
- rapidfireai/ml/callbacks.py +176 -0
- rapidfireai/ml/checkpoint_utils.py +540 -0
- rapidfireai/ml/trainer.py +309 -0
- rapidfireai/start.sh +634 -0
- rapidfireai/utils/__init__.py +0 -0
- rapidfireai/utils/automl_utils.py +51 -0
- rapidfireai/utils/constants.py +141 -0
- rapidfireai/utils/datapaths.py +69 -0
- rapidfireai/utils/exceptions.py +82 -0
- rapidfireai/utils/experiment_utils.py +370 -0
- rapidfireai/utils/logging.py +87 -0
- rapidfireai/utils/mlflow_manager.py +121 -0
- rapidfireai/utils/serialize.py +15 -0
- rapidfireai/utils/shm_manager.py +469 -0
- rapidfireai/utils/trainer_config.py +23 -0
- rapidfireai/utils/worker_manager.py +219 -0
- rapidfireai/version.py +6 -0
- rapidfireai-0.9.9.dist-info/METADATA +242 -0
- rapidfireai-0.9.9.dist-info/RECORD +318 -0
- rapidfireai-0.9.9.dist-info/entry_points.txt +2 -0
- rapidfireai-0.0.1.dist-info/METADATA +0 -37
- rapidfireai-0.0.1.dist-info/RECORD +0 -6
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/WHEEL +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gc
|
|
3
|
+
import threading
|
|
4
|
+
from multiprocessing import Lock, Manager
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from rapidfireai.utils.constants import SHM_MIN_FREE_SPACE, SHM_WARN_THRESHOLD, SHMObjectType
|
|
9
|
+
from rapidfireai.utils.exceptions import InsufficientSharedMemoryException
|
|
10
|
+
from rapidfireai.utils.logging import RFLogger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_shm_usage():
    """Report storage usage of the ``/dev/shm`` mount.

    Returns:
        A dict with ``"total"``, ``"used"`` and ``"free"`` expressed in GiB (floats)
        and ``"percent_used"``, the used fraction of the total as a percentage.

    Raises:
        OSError: propagated from ``shutil.disk_usage`` if ``/dev/shm`` cannot be inspected.
        ZeroDivisionError: if the mount reports a total size of zero bytes.
    """
    import shutil

    bytes_per_gib = 1024**3
    usage = shutil.disk_usage("/dev/shm")
    return {
        "total": usage.total / bytes_per_gib,
        "used": usage.used / bytes_per_gib,
        "free": usage.free / bytes_per_gib,
        "percent_used": (usage.used / usage.total) * 100,
    }
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _estimate_tensor_size_gib(obj):
|
|
25
|
+
"""Recursively estimate the size of tensors in a nested structure in GiB."""
|
|
26
|
+
if isinstance(obj, torch.Tensor):
|
|
27
|
+
# Calculate size in bytes: numel * element_size
|
|
28
|
+
size_bytes = obj.numel() * obj.element_size()
|
|
29
|
+
return size_bytes / (1024**3)
|
|
30
|
+
elif isinstance(obj, dict):
|
|
31
|
+
return sum(_estimate_tensor_size_gib(v) for v in obj.values())
|
|
32
|
+
elif isinstance(obj, (list, tuple)):
|
|
33
|
+
return sum(_estimate_tensor_size_gib(item) for item in obj)
|
|
34
|
+
else:
|
|
35
|
+
return 0.0
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _verify_sufficient_model_size(model: torch.nn.Module | None, logger: RFLogger):
    """Check that /dev/shm can hold ``model`` while leaving SHM_MIN_FREE_SPACE GiB free.

    The model footprint is estimated from its parameters and buffers. If /dev/shm
    cannot be inspected, or the estimate fails, a warning is logged and the check is
    skipped. When the model has a non-zero footprint, a warning is also logged if
    /dev/shm usage already exceeds SHM_WARN_THRESHOLD percent.

    Args:
        model: Model about to be copied into shared memory. ``None`` or a ``str``
            contribute no size, so the check is skipped.
        logger: Logger used for warnings.

    Returns:
        True when the check passes or is skipped.

    Raises:
        InsufficientSharedMemoryException: If saving the model would leave less than
            SHM_MIN_FREE_SPACE GiB free in /dev/shm.
    """
    try:
        shm_info = _get_shm_usage()
        free_gib = shm_info["free"]
        total_gib = shm_info["total"]
        percent_used = shm_info["percent_used"]

        # Estimate the size of the model to be saved
        model_size_gib = 0.0
        if model is not None and not isinstance(model, str):
            for param in model.parameters():
                if param.data is not None:
                    model_size_gib += _estimate_tensor_size_gib(param.data)
            for buffer in model.buffers():
                if buffer is not None:
                    model_size_gib += _estimate_tensor_size_gib(buffer)
    except Exception as e:
        # Best effort: skip the check entirely rather than compare free space against
        # a partially accumulated (underestimated) model size.
        logger.warning(f"Could not check shared memory space at /dev/shm: {e}")
        return True

    if model_size_gib <= 0.0:
        return True

    # Warn if usage is high
    if percent_used > SHM_WARN_THRESHOLD:
        logger.warning(
            f"Shared memory usage is high: {percent_used:.1f}%. Available space: {free_gib:.2f}/{total_gib:.2f} GiB"
        )

    # Check if at least SHM_MIN_FREE_SPACE GiB will be left after saving the model
    if free_gib - model_size_gib < SHM_MIN_FREE_SPACE:
        raise InsufficientSharedMemoryException(
            f"Insufficient shared memory space: {free_gib:.2f} GiB available, model size: "
            f"{model_size_gib:.2f} GiB, need at least {SHM_MIN_FREE_SPACE} GiB remaining after save"
        )

    return True
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _verify_sufficient_ref_state_dict_size(ref_state_dict: dict, logger: RFLogger):
    """Check that /dev/shm can hold ``ref_state_dict`` with SHM_MIN_FREE_SPACE GiB to spare.

    The footprint is the total size of the tensors nested in ``ref_state_dict``. Any
    failure while probing /dev/shm or estimating that size is logged as a warning, and
    the check is skipped.

    Returns:
        True when the check passes or is skipped.

    Raises:
        InsufficientSharedMemoryException: If saving would leave less than
            SHM_MIN_FREE_SPACE GiB free in /dev/shm.
    """
    try:
        usage = _get_shm_usage()
        free_gib, total_gib, percent_used = usage["free"], usage["total"], usage["percent_used"]
        state_dict_size_gib = _estimate_tensor_size_gib(ref_state_dict)
    except Exception as e:
        logger.warning(f"Could not check shared memory space at /dev/shm: {e}")
        return True

    # Nothing tensor-valued to store, so there is nothing to validate.
    if not state_dict_size_gib > 0.0:
        return True

    if percent_used > SHM_WARN_THRESHOLD:
        logger.warning(
            f"Shared memory usage is high: {percent_used:.1f}%. Available space: {free_gib:.2f}/{total_gib:.2f} GiB"
        )

    remaining_gib = free_gib - state_dict_size_gib
    if remaining_gib < SHM_MIN_FREE_SPACE:
        raise InsufficientSharedMemoryException(
            f"Insufficient shared memory space: {free_gib:.2f} GiB available, state dict size: "
            f"{state_dict_size_gib:.2f} GiB, need at least {SHM_MIN_FREE_SPACE} GiB remaining after save"
        )

    return True
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class SharedMemoryManager:
|
|
114
|
+
"""Manages PyTorch models and checkpoints in shared memory across multiple processes."""
|
|
115
|
+
|
|
116
|
+
def __init__(self, name: str, registry=None, multiprocess_lock=None):
    """Initialize the shared memory manager with a process-safe registry and locks.

    Args:
        name: Logger name passed to ``RFLogger().create_logger``.
        registry: Dict-like registry to use. When omitted, a new
            ``multiprocessing.Manager().dict()`` is created.
        multiprocess_lock: Lock to use for cross-process registry updates. When
            omitted, a ``Manager().Lock()`` is created.
    """
    # A Manager is required whenever this instance has to create either the registry
    # or the process lock itself. Checking only ``registry`` would leave
    # ``self._manager`` undefined when a registry is supplied without a lock.
    if registry is None or multiprocess_lock is None:
        self._manager = Manager()

    # initialize registry
    self._registry = self._manager.dict() if registry is None else registry

    # initialize multiprocess lock
    self._process_lock = self._manager.Lock() if multiprocess_lock is None else multiprocess_lock

    # initialize thread lock for operations within a single process
    self._thread_lock = threading.Lock()

    self.logger = RFLogger().create_logger(name)
|
|
135
|
+
|
|
136
|
+
# shared memory operations
|
|
137
|
+
def _safe_tensor_to_shared_memory(self, tensor: torch.Tensor | None) -> torch.Tensor | None:
|
|
138
|
+
"""Safely convert a tensor to shared memory format"""
|
|
139
|
+
if tensor is None:
|
|
140
|
+
return None
|
|
141
|
+
tensor = tensor.cpu()
|
|
142
|
+
tensor = tensor.detach().contiguous().clone()
|
|
143
|
+
tensor.share_memory_()
|
|
144
|
+
|
|
145
|
+
return tensor
|
|
146
|
+
|
|
147
|
+
def _move_tensors_to_shared_memory(self, obj):
    """Recursively move every tensor in a nested structure into shared memory.

    Each tensor is modified in place with ``share_memory_()`` and returned as the
    same object. Dict values are processed into a new plain ``dict``. Lists and
    tuples, including namedtuples, are rebuilt with their original type. Any other
    object is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        obj.share_memory_()
        return obj
    if isinstance(obj, dict):
        return {k: self._move_tensors_to_shared_memory(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [self._move_tensors_to_shared_memory(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(obj)(*items)
        return type(obj)(items)
    return obj
|
|
158
|
+
|
|
159
|
+
def _move_model_to_shared_memory(self, model):
    """Move a model's tensors into shared memory and capture bitsandbytes 4-bit state.

    The model is moved to CPU with ``model.cpu()``. Every parameter's ``.data`` and
    every registered buffer is then replaced by the copy returned from
    ``self._safe_tensor_to_shared_memory``. For each bitsandbytes 4-bit linear layer,
    this collects:

    - the public non-callable attributes of the weight's ``quant_state`` and of its
      nested ``state2``;
    - a few weight attributes;
    - the relevant class names.

    Args:
        model: The ``torch.nn.Module`` to share. Its tensors are replaced in place.

    Returns:
        A tuple ``(model, bnb_modules)``. ``model`` is the result of ``model.cpu()``.
        ``bnb_modules`` maps module names (as yielded by ``named_modules()``) to the
        captured bitsandbytes attributes. It is empty when bitsandbytes is not
        installed or the model has no 4-bit layers.
    """
    model = model.cpu()

    for _, param in model.named_parameters():
        if param.data is not None:
            param.data = self._safe_tensor_to_shared_memory(param.data)

    for name, buffer in model.named_buffers():
        if not isinstance(buffer, torch.Tensor):
            continue
        # Re-register the shared copy on the owning submodule under the same name.
        *parent_path, attr_name = name.split(".")
        parent_module = model
        for attr in parent_path:
            parent_module = getattr(parent_module, attr)
        setattr(parent_module, attr_name, self._safe_tensor_to_shared_memory(buffer))

    # Import once, not per module. Without bitsandbytes installed, no module can be
    # a bitsandbytes layer, so there is nothing to capture.
    try:
        import bitsandbytes as bnb
    except ImportError:
        return model, {}

    # NOTE(review): Params4bit appears to be a parameter class rather than a module,
    # so it likely never matches an entry of named_modules(); kept for parity.
    bnb_layer_types = (bnb.nn.Linear4bit, bnb.nn.LinearFP4, bnb.nn.LinearNF4, bnb.nn.Params4bit)

    def _capture_public_state(state_obj):
        """Snapshot public, non-callable attributes of ``state_obj``; tensors become shared copies."""
        captured = {}
        for attr_name in dir(state_obj):
            if attr_name.startswith("_") or not hasattr(state_obj, attr_name):
                continue
            attr_val = getattr(state_obj, attr_name)
            if isinstance(attr_val, torch.Tensor):
                captured[attr_name] = self._safe_tensor_to_shared_memory(attr_val)
            elif not callable(attr_val):
                captured[attr_name] = attr_val
        return captured

    bnb_modules = {}
    for name, module in model.named_modules():
        if not hasattr(module, "weight") or not isinstance(module, bnb_layer_types):
            continue

        bnb_attrs = {}
        weight = module.weight

        # NOTE(review): if this weight is a registered parameter, the parameter loop
        # above already shared it; this makes a second shared copy.
        if hasattr(weight, "data") and weight.data is not None:
            weight.data = self._safe_tensor_to_shared_memory(weight.data)

        quant_state = getattr(weight, "quant_state", None)
        if quant_state is not None:
            bnb_attrs["quant_state_data"] = _capture_public_state(quant_state)
            state2 = getattr(quant_state, "state2", None)
            if state2 is not None:
                bnb_attrs["state2_data"] = _capture_public_state(state2)
            bnb_attrs["quant_state_class"] = type(quant_state).__name__

        for attr in ("compress_statistics", "quant_type", "blocksize", "bnb_quantized"):
            if hasattr(weight, attr):
                attr_val = getattr(weight, attr)
                if isinstance(attr_val, torch.Tensor):
                    attr_val = self._safe_tensor_to_shared_memory(attr_val)
                bnb_attrs[attr] = attr_val

        bnb_attrs["weight_class"] = type(weight).__name__
        bnb_modules[name] = bnb_attrs

    return model, bnb_modules
|
|
234
|
+
|
|
235
|
+
# model object operations
|
|
236
|
+
def _save_full_model(self, model_id: str, model_data: dict, model_object_type: SHMObjectType):
    """Move a full model and its tokenizer into shared memory and register them.

    ``model_id`` is either a run id or the name of a base model. One registry entry can hold
    several object types at once (FULL_MODEL, REF_FULL_MODEL, REF_STATE_DICT, CHECKPOINTS), so
    the save is skipped only when ``model_object_type`` itself is already stored for
    ``model_id``. Other object types already present on the entry are preserved.

    Args:
        model_id: Registry key (run id or base-model name).
        model_data: Dict holding the model under ``model_object_type`` and the tokenizer under
            ``"tokenizer"``.
        model_object_type: Which full-model slot to fill (``save_model_object`` routes
            BASE_MODEL, FULL_MODEL and REF_FULL_MODEL here).
    """
    with self._process_lock if self._process_lock else self._thread_lock:
        existing_entry = self._registry.get(model_id)

        # Only this object type blocks the save. An entry created for a different type
        # (e.g. REF_FULL_MODEL or checkpoints under the same run id) must not prevent
        # this model from being stored.
        if existing_entry is not None and existing_entry.get(model_object_type):
            self.logger.debug(f"Model {model_id} already exists in shared memory. Skipping save.")
            return

        # verify sufficient shared memory space before saving model
        _verify_sufficient_model_size(model_data[model_object_type], self.logger)

        # move model (and the bnb metadata collected alongside it) to shared memory
        model, bnb_modules = self._move_model_to_shared_memory(model_data[model_object_type])
        shared_model = {
            model_object_type: model,
            "tokenizer": model_data["tokenizer"],
            "bnb_modules": self._move_tensors_to_shared_memory(bnb_modules),
        }

        # Register only after the move succeeded, so a failure leaves no empty placeholder
        # that would make later saves skip. Build a fresh dict and re-assign it so a
        # Manager-backed registry sees the change.
        model_entry = dict(existing_entry) if existing_entry is not None else {}
        model_entry[model_object_type] = shared_model
        self._registry[model_id] = model_entry

        self.logger.debug(f"Saved {model_object_type.value} for run {model_id}")
|
|
264
|
+
|
|
265
|
+
def _save_ref_state_dict(self, model_id: str, ref_state_dict: dict):
    """Copy a reference state dict into shared memory and attach it to ``model_id``.

    Any other object types already stored on ``model_id``'s registry entry are kept; only the
    REF_STATE_DICT slot is (re)written.

    NOTE(review): this path takes only the thread lock, while ``_save_full_model`` and
    ``delete_model_object`` prefer the process lock when one is set — confirm this is intended.
    """
    with self._thread_lock:
        # Check shared-memory headroom before copying anything.
        _verify_sufficient_ref_state_dict_size(ref_state_dict, self.logger)

        # Make sure a registry slot exists for this model id.
        if model_id not in self._registry:
            self._registry[model_id] = {SHMObjectType.REF_STATE_DICT: {}}

        shared_copy = self._move_tensors_to_shared_memory(ref_state_dict)

        # Re-assign a fresh dict so a Manager-backed registry picks up the change.
        self._registry[model_id] = {
            **self._registry[model_id],
            SHMObjectType.REF_STATE_DICT: shared_copy,
        }

        self.logger.debug(f"Saved ref_state_dict for {model_id}")
|
|
282
|
+
|
|
283
|
+
def _update_checkpoints(self, model_id: str, checkpoint_updates: dict):
    """Merge ``checkpoint_updates`` into the checkpoints stored for ``model_id``.

    Merge rules, applied recursively over nested dicts:

    - an existing shared tensor with the same shape and dtype is overwritten in place;
    - an existing dict receiving a dict is merged key by key;
    - anything else (new key, shape/dtype change, non-shared tensor, or a type change such
      as tensor -> dict or None -> tensor) is replaced by a fresh value. Tensors are detached,
      moved to CPU and placed in shared memory; dicts go through
      ``_move_tensors_to_shared_memory``; other values are stored as-is.

    Args:
        model_id: Registry key of the run whose checkpoints are updated.
        checkpoint_updates: Possibly nested dict of new checkpoint values.
    """
    with self._thread_lock:
        # create model entry in registry
        if model_id not in self._registry:
            self._registry[model_id] = {SHMObjectType.CHECKPOINTS: {}}

        model_entry = self._registry[model_id]
        if SHMObjectType.CHECKPOINTS not in model_entry:
            model_entry[SHMObjectType.CHECKPOINTS] = {}
        current_checkpoints = model_entry[SHMObjectType.CHECKPOINTS]

        updates_made = {"in_place": 0, "new_keys": 0}

        def to_shared(value):
            """Return a shared-memory copy of a tensor or dict value; other values pass through."""
            if isinstance(value, torch.Tensor):
                # detach() so the stored copy carries no autograd history
                shared = value.detach().cpu().clone()
                shared.share_memory_()
                return shared
            if isinstance(value, dict):
                return self._move_tensors_to_shared_memory(value)
            return value

        def update_nested_dict(current_dict, updates_dict, path=""):
            """Apply ``updates_dict`` onto ``current_dict`` following the merge rules above."""
            for key, new_value in updates_dict.items():
                current_path = f"{path}.{key}" if path else key

                if key in current_dict:
                    current_value = current_dict[key]

                    if isinstance(new_value, torch.Tensor) and isinstance(current_value, torch.Tensor):
                        if (
                            current_value.shape == new_value.shape
                            and current_value.dtype == new_value.dtype
                            and current_value.is_shared()
                        ):
                            # copy_ accepts a source on another device; detach() keeps the
                            # shared tensor free of autograd history.
                            current_value.copy_(new_value.detach())
                            updates_made["in_place"] += 1
                            continue
                        self.logger.debug(f"New tensor (shape/type change): {current_path}")

                    elif isinstance(new_value, dict) and isinstance(current_value, dict):
                        # Recursively update nested dicts
                        update_nested_dict(current_value, new_value, current_path)
                        continue

                # New key, or the stored value cannot be updated in place: store a fresh
                # value (shared-memory copy for tensors and dicts).
                current_dict[key] = to_shared(new_value)
                if isinstance(new_value, (torch.Tensor, dict)):
                    updates_made["new_keys"] += 1

        # Update the checkpoints
        update_nested_dict(current_checkpoints, checkpoint_updates)

        # Update the registry entry to ensure Manager sees changes
        updated_entry = dict(model_entry)
        updated_entry[SHMObjectType.CHECKPOINTS] = current_checkpoints
        self._registry[model_id] = updated_entry

        self.logger.debug(f"Checkpoint update:{updates_made['in_place']} in-place, {updates_made['new_keys']} new")
|
|
353
|
+
|
|
354
|
+
def get_shm_objects(self) -> tuple[dict, Lock]:
    """Hand out the shared registry together with the process lock.

    The lock is returned as stored in ``self._process_lock``. Other methods here fall back to
    the thread lock when it is unset, so callers should be prepared for a falsy value.
    """
    registry, lock = self._registry, self._process_lock
    return registry, lock
|
|
357
|
+
|
|
358
|
+
def load_model_object(self, model_id: str, model_object_type: SHMObjectType):
    """Fetch one stored object for ``model_id`` from the shared registry.

    Returns:
        The value stored under ``model_object_type`` on ``model_id``'s entry. Returns None
        when that type was never stored. Also returns None, after logging a warning, when
        ``model_id`` has no registry entry at all.
    """
    entry = self._registry.get(model_id)
    if entry is not None:
        return entry.get(model_object_type)
    self.logger.warning(f"Model {model_id} not found in shared memory")
    return None
|
|
366
|
+
|
|
367
|
+
def save_model_object(self, model_id: str, model_object_type: SHMObjectType, model_object: dict):
    """Store ``model_object`` under ``model_id`` according to ``model_object_type``.

    Dispatch:
        BASE_MODEL / FULL_MODEL / REF_FULL_MODEL -> ``_save_full_model`` (``model_object`` is
            the model-data dict holding the model and its tokenizer);
        REF_STATE_DICT -> ``_save_ref_state_dict``;
        CHECKPOINTS -> ``_update_checkpoints`` (merged into existing checkpoints).

    Any other object type is not stored. A warning is logged instead of silently ignoring it.
    """
    # save model object
    if model_object_type in (SHMObjectType.BASE_MODEL, SHMObjectType.FULL_MODEL, SHMObjectType.REF_FULL_MODEL):
        self._save_full_model(model_id, model_object, model_object_type)
    elif model_object_type == SHMObjectType.REF_STATE_DICT:
        self._save_ref_state_dict(model_id, model_object)
    elif model_object_type == SHMObjectType.CHECKPOINTS:
        self._update_checkpoints(model_id, model_object)
    else:
        self.logger.warning(
            f"Unsupported model object type {model_object_type} for model {model_id}; nothing saved"
        )
|
|
376
|
+
|
|
377
|
+
def delete_model_object(self, model_id: str, base_model_name: str | None = None):
    """Delete ``model_id`` (and optionally its base model) from the shared registry.

    Each stored object type on ``model_id``'s entry is dropped, then the entry itself is
    removed. When ``base_model_name`` is given and registered, that whole entry is removed too.
    The method finishes with ``gc.collect()`` and, when CUDA is available,
    ``torch.cuda.empty_cache()``. An unknown ``model_id`` is logged as a warning and ignored.

    Args:
        model_id: Registry key of the run to remove.
        base_model_name: Optional registry key of a base-model entry to remove as well.
    """
    with self._process_lock if self._process_lock else self._thread_lock:
        if model_id not in self._registry:
            self.logger.warning(f"Model '{model_id}' not found in shared memory during delete")
            return

        # Fetch the entry once: on a Manager-backed registry every item access returns a copy.
        model_entry = self._registry[model_id]

        # remove stored objects
        # TODO: add code to save checkpoints / full_model to disk before deleting
        for object_type, label in (
            (SHMObjectType.CHECKPOINTS, "checkpoints"),
            (SHMObjectType.FULL_MODEL, "full_model"),
            (SHMObjectType.REF_STATE_DICT, "ref_state_dict"),
            (SHMObjectType.REF_FULL_MODEL, "ref_full_model"),
        ):
            if model_entry.get(object_type):
                del model_entry[object_type]
                self.logger.debug(f"Deleted {label} for model {model_id} from shared memory")

        # remove shared objects (entire registry entry is deleted for base_model, not just SHMObjectType.BASE_MODEL key)
        if base_model_name and base_model_name in self._registry:
            del self._registry[base_model_name]
            self.logger.debug(f"Deleted base_model for model {model_id} from shared memory")

        # remove registry entry; pop() tolerates base_model_name == model_id (already removed above)
        self._registry.pop(model_id, None)
        self.logger.debug(f"Deleted model registry entry for {model_id} from shared memory")

        # Force garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.debug("Force garbage collection and empty cache")
|
|
433
|
+
|
|
434
|
+
def create_warm_start_checkpoint(self, model_id: str, warm_started_from: str):
    """Copy the warm-start state of ``warm_started_from`` into ``model_id``.

    Deep copies of the source entry's FULL_MODEL, REF_STATE_DICT and CHECKPOINTS are written
    to ``model_id``'s entry, which is created if needed. Other keys already on ``model_id``'s
    entry are preserved.

    NOTE(review): guarded by the thread lock only, unlike ``_save_full_model`` /
    ``delete_model_object`` which prefer the process lock when set — confirm this is intended.

    Raises:
        KeyError: If ``warm_started_from`` has no registry entry, or its entry lacks any of
            the three object types above.
    """
    with self._thread_lock:
        if warm_started_from not in self._registry:
            raise KeyError(f"Run '{warm_started_from}' not found in shared memory")

        # Read the source entry once (a Manager-backed registry returns a copy per access).
        source_entry = dict(self._registry[warm_started_from])
        copied_objects = {
            object_type: copy.deepcopy(source_entry[object_type])
            for object_type in (
                SHMObjectType.FULL_MODEL,
                SHMObjectType.REF_STATE_DICT,
                SHMObjectType.CHECKPOINTS,
            )
        }

        # Build the target entry only after every copy succeeded, so a failure does not
        # leave a half-initialized entry for model_id behind.
        model_entry = dict(self._registry.get(model_id) or {})
        model_entry.update(copied_objects)
        self._registry[model_id] = model_entry

        self.logger.debug(f"Copied warm start checkpoint from {warm_started_from} to {model_id}")
|
|
460
|
+
|
|
461
|
+
def list_models(self):
    """Return a snapshot list of every model id currently held in the shared registry.

    The process lock is used when set, otherwise the thread lock.
    """
    guard = self._process_lock or self._thread_lock
    with guard:
        return [model_id for model_id in self._registry.keys()]
|
|
465
|
+
|
|
466
|
+
def model_exists(self, model_id: str):
    """Report whether ``model_id`` has an entry in the shared registry.

    The membership test runs under the process lock when set, otherwise the thread lock.
    """
    lock = self._thread_lock if not self._process_lock else self._process_lock
    with lock:
        found = model_id in self._registry
    return found
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""This module contains the TrainerConfig class which is responsible for configuring the trainer."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class TrainerConfig:
|
|
11
|
+
"""Trainer configuration"""
|
|
12
|
+
|
|
13
|
+
worker_id: int
|
|
14
|
+
run_id: int
|
|
15
|
+
mlflow_run_id: str
|
|
16
|
+
config_leaf: dict[str, Any]
|
|
17
|
+
total_steps: int
|
|
18
|
+
completed_steps: int
|
|
19
|
+
create_model_fn: Callable
|
|
20
|
+
train_dataset: torch.utils.data.Dataset
|
|
21
|
+
eval_dataset: Optional[torch.utils.data.Dataset]
|
|
22
|
+
warm_started_from: int | None
|
|
23
|
+
num_epochs_completed: int
|