rapidfireai 0.0.1__py3-none-any.whl → 0.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidfireai might be problematic. Click here for more details.

Files changed (320) hide show
  1. rapidfireai/__init__.py +11 -5
  2. rapidfireai/automl/__init__.py +20 -0
  3. rapidfireai/automl/base.py +48 -0
  4. rapidfireai/automl/datatypes.py +42 -0
  5. rapidfireai/automl/grid_search.py +125 -0
  6. rapidfireai/automl/model_config.py +102 -0
  7. rapidfireai/automl/random_search.py +145 -0
  8. rapidfireai/backend/__init__.py +0 -0
  9. rapidfireai/backend/chunks.py +63 -0
  10. rapidfireai/backend/controller.py +637 -0
  11. rapidfireai/backend/scheduler.py +137 -0
  12. rapidfireai/backend/worker.py +272 -0
  13. rapidfireai/cli.py +380 -0
  14. rapidfireai/db/__init__.py +0 -0
  15. rapidfireai/db/db_interface.py +135 -0
  16. rapidfireai/db/rf_db.py +694 -0
  17. rapidfireai/db/tables.sql +64 -0
  18. rapidfireai/dispatcher/dispatcher.py +391 -0
  19. rapidfireai/dispatcher/gunicorn.conf.py +25 -0
  20. rapidfireai/experiment.py +168 -0
  21. rapidfireai/frontend/build/asset-manifest.json +276 -0
  22. rapidfireai/frontend/build/favicon.ico +0 -0
  23. rapidfireai/frontend/build/index.html +1 -0
  24. rapidfireai/frontend/build/manifest.json +15 -0
  25. rapidfireai/frontend/build/pdf.worker.js +1 -0
  26. rapidfireai/frontend/build/report.html +39 -0
  27. rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
  28. rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
  29. rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
  30. rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
  31. rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
  32. rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
  33. rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
  34. rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
  35. rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
  36. rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
  37. rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
  38. rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
  39. rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
  40. rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
  41. rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
  42. rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
  43. rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
  44. rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
  45. rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
  46. rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
  47. rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
  48. rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
  49. rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
  50. rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
  51. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
  52. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
  53. rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
  54. rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
  55. rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
  56. rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
  57. rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
  58. rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
  59. rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
  60. rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
  61. rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
  62. rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
  63. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
  64. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
  65. rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
  66. rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
  67. rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
  68. rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
  69. rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
  70. rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
  71. rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
  72. rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
  73. rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
  74. rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
  75. rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
  76. rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
  77. rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
  78. rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
  79. rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
  80. rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
  81. rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
  82. rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
  83. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
  84. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
  85. rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
  86. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
  87. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
  88. rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
  89. rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
  90. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
  91. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
  92. rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
  93. rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
  94. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
  95. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
  96. rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
  97. rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
  98. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
  99. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
  100. rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
  101. rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
  102. rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
  103. rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
  104. rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
  105. rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
  106. rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
  107. rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
  108. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
  109. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
  110. rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
  111. rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
  112. rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
  113. rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
  114. rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
  115. rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
  116. rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
  117. rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
  118. rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
  119. rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
  120. rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
  121. rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
  122. rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
  123. rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
  124. rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
  125. rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
  126. rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
  127. rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
  128. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
  129. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
  130. rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
  131. rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
  132. rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
  133. rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
  134. rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
  135. rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
  136. rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
  137. rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
  138. rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
  139. rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
  140. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
  141. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
  142. rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
  143. rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
  144. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
  145. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
  146. rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
  147. rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
  148. rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
  149. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
  150. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
  151. rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
  152. rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
  153. rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
  154. rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
  155. rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
  156. rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
  157. rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
  158. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
  159. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
  160. rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
  161. rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
  162. rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
  163. rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
  164. rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
  165. rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
  166. rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
  167. rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
  168. rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
  169. rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
  170. rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
  171. rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
  172. rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
  173. rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
  174. rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
  175. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
  176. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
  177. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
  178. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
  179. rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
  180. rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
  181. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
  182. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
  183. rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
  184. rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
  185. rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
  186. rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
  187. rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
  188. rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
  189. rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
  190. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
  191. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
  192. rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
  193. rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
  194. rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
  195. rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
  196. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
  197. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
  198. rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
  199. rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
  200. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
  201. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
  202. rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
  203. rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
  204. rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
  205. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
  206. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
  207. rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
  208. rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
  209. rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
  210. rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
  211. rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
  212. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
  213. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
  214. rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
  215. rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
  216. rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
  217. rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
  218. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
  219. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
  220. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
  221. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
  222. rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
  223. rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
  224. rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
  225. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
  226. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
  227. rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
  228. rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
  229. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
  230. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
  231. rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
  232. rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
  233. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
  234. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
  235. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
  236. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
  237. rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
  238. rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
  239. rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
  240. rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
  241. rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
  242. rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
  243. rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
  244. rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
  245. rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
  246. rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
  247. rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
  248. rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
  249. rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
  250. rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
  251. rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
  252. rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
  253. rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
  254. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
  255. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
  256. rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
  257. rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
  258. rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
  259. rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
  260. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
  261. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
  262. rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
  263. rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
  264. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
  265. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
  266. rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
  267. rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
  268. rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
  269. rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
  270. rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
  271. rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
  272. rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
  273. rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
  274. rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
  275. rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
  276. rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
  277. rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
  278. rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
  279. rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
  280. rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
  281. rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
  282. rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
  283. rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
  284. rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
  285. rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
  286. rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
  287. rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
  288. rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
  289. rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
  290. rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
  291. rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
  292. rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
  293. rapidfireai/frontend/proxy_middleware.py +233 -0
  294. rapidfireai/frontend/server.py +25 -0
  295. rapidfireai/ml/__init__.py +0 -0
  296. rapidfireai/ml/callbacks.py +176 -0
  297. rapidfireai/ml/checkpoint_utils.py +540 -0
  298. rapidfireai/ml/trainer.py +309 -0
  299. rapidfireai/start.sh +634 -0
  300. rapidfireai/utils/__init__.py +0 -0
  301. rapidfireai/utils/automl_utils.py +51 -0
  302. rapidfireai/utils/constants.py +141 -0
  303. rapidfireai/utils/datapaths.py +69 -0
  304. rapidfireai/utils/exceptions.py +82 -0
  305. rapidfireai/utils/experiment_utils.py +370 -0
  306. rapidfireai/utils/logging.py +87 -0
  307. rapidfireai/utils/mlflow_manager.py +121 -0
  308. rapidfireai/utils/serialize.py +15 -0
  309. rapidfireai/utils/shm_manager.py +469 -0
  310. rapidfireai/utils/trainer_config.py +23 -0
  311. rapidfireai/utils/worker_manager.py +219 -0
  312. rapidfireai/version.py +6 -0
  313. rapidfireai-0.9.10.dist-info/METADATA +247 -0
  314. rapidfireai-0.9.10.dist-info/RECORD +318 -0
  315. rapidfireai-0.9.10.dist-info/entry_points.txt +2 -0
  316. rapidfireai-0.0.1.dist-info/METADATA +0 -37
  317. rapidfireai-0.0.1.dist-info/RECORD +0 -6
  318. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/WHEEL +0 -0
  319. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/licenses/LICENSE +0 -0
  320. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,469 @@
1
+ import copy
2
+ import gc
3
+ import threading
4
+ from multiprocessing import Lock, Manager
5
+
6
+ import torch
7
+
8
+ from rapidfireai.utils.constants import SHM_MIN_FREE_SPACE, SHM_WARN_THRESHOLD, SHMObjectType
9
+ from rapidfireai.utils.exceptions import InsufficientSharedMemoryException
10
+ from rapidfireai.utils.logging import RFLogger
11
+
12
+
13
+ def _get_shm_usage():
14
+ """Get shared memory storage usage information in GiB."""
15
+ import shutil
16
+
17
+ stat = shutil.disk_usage("/dev/shm")
18
+ total_gib = stat.total / (1024**3)
19
+ used_gib = stat.used / (1024**3)
20
+ free_gib = stat.free / (1024**3)
21
+ return {"total": total_gib, "used": used_gib, "free": free_gib, "percent_used": (stat.used / stat.total) * 100}
22
+
23
+
24
+ def _estimate_tensor_size_gib(obj):
25
+ """Recursively estimate the size of tensors in a nested structure in GiB."""
26
+ if isinstance(obj, torch.Tensor):
27
+ # Calculate size in bytes: numel * element_size
28
+ size_bytes = obj.numel() * obj.element_size()
29
+ return size_bytes / (1024**3)
30
+ elif isinstance(obj, dict):
31
+ return sum(_estimate_tensor_size_gib(v) for v in obj.values())
32
+ elif isinstance(obj, (list, tuple)):
33
+ return sum(_estimate_tensor_size_gib(item) for item in obj)
34
+ else:
35
+ return 0.0
36
+
37
+
38
def _verify_sufficient_shm_space(estimate_size_gib, size_label: str, logger):
    """Check that /dev/shm can hold a payload and still keep the required headroom.

    Shared implementation behind ``_verify_sufficient_model_size`` and
    ``_verify_sufficient_ref_state_dict_size``.

    Args:
        estimate_size_gib: Zero-argument callable returning the payload size
            in GiB. It is invoked inside the same ``try`` as the usage query,
            so a failure in either step is handled the same way.
        size_label: Noun used in the error message (e.g. ``"model size"``).
        logger: Logger used for warnings.

    Returns:
        ``True`` when the check passes or had to be skipped.

    Raises:
        InsufficientSharedMemoryException: If fewer than ``SHM_MIN_FREE_SPACE``
            GiB would remain free after storing the payload.
    """
    try:
        shm_info = _get_shm_usage()
        payload_gib = estimate_size_gib()
    except Exception as e:
        # Best effort: if usage cannot be read or the payload cannot be sized,
        # warn and skip the check. A partially accumulated size is never used.
        logger.warning(f"Could not check shared memory space at /dev/shm: {e}")
        return True

    if not shm_info or payload_gib <= 0.0:
        # Nothing measurable to store, so there is nothing to verify.
        return True

    free_gib = shm_info["free"]
    total_gib = shm_info["total"]
    percent_used = shm_info["percent_used"]

    # Warn if usage is high
    if percent_used > SHM_WARN_THRESHOLD:
        logger.warning(
            f"Shared memory usage is high: {percent_used:.1f}%. Available space: {free_gib:.2f}/{total_gib:.2f} GiB"
        )

    # Check if at least SHM_MIN_FREE_SPACE GiB will be left after saving the payload
    if free_gib - payload_gib < SHM_MIN_FREE_SPACE:
        raise InsufficientSharedMemoryException(
            f"Insufficient shared memory space: {free_gib:.2f} GiB available, {size_label}: "
            f"{payload_gib:.2f} GiB, need at least {SHM_MIN_FREE_SPACE} GiB remaining after save"
        )

    return True


def _verify_sufficient_model_size(model: torch.nn.Module | None, logger: RFLogger):
    """Verify that /dev/shm has room for ``model``'s parameters and buffers.

    Args:
        model: Model about to be stored. ``None`` and ``str`` values are
            treated as having no measurable size, which skips the check.
        logger: Logger used for warnings.

    Returns:
        ``True`` when the check passes or is skipped.

    Raises:
        InsufficientSharedMemoryException: If storing the model would leave
            less than ``SHM_MIN_FREE_SPACE`` GiB free.
    """

    def _model_size_gib():
        if model is None or isinstance(model, str):
            return 0.0
        params_gib = sum(_estimate_tensor_size_gib(param.data) for param in model.parameters())
        buffers_gib = sum(_estimate_tensor_size_gib(buffer) for buffer in model.buffers())
        return params_gib + buffers_gib

    return _verify_sufficient_shm_space(_model_size_gib, "model size", logger)


def _verify_sufficient_ref_state_dict_size(ref_state_dict: dict, logger: RFLogger):
    """Verify that /dev/shm has room for the tensors in ``ref_state_dict``.

    Args:
        ref_state_dict: State dict about to be stored; its size is the sum of
            all tensors nested inside it.
        logger: Logger used for warnings.

    Returns:
        ``True`` when the check passes or is skipped.

    Raises:
        InsufficientSharedMemoryException: If storing the state dict would
            leave less than ``SHM_MIN_FREE_SPACE`` GiB free.
    """
    return _verify_sufficient_shm_space(
        lambda: _estimate_tensor_size_gib(ref_state_dict), "state dict size", logger
    )
111
+
112
+
113
+ class SharedMemoryManager:
114
+ """Manages PyTorch models and checkpoints in shared memory across multiple processes."""
115
+
116
def __init__(self, name: str, registry=None, multiprocess_lock=None):
    """Initialize the shared memory manager with process-safe registry and locks.

    Args:
        name: Name passed to ``RFLogger().create_logger``.
        registry: Existing registry mapping to reuse. When ``None`` a new
            ``Manager().dict()`` is created.
        multiprocess_lock: Existing cross-process lock to reuse. When ``None``
            a new ``Manager().Lock()`` is created.
    """
    # A Manager is needed for whichever of the two objects the caller did not
    # supply. Previously it was only created when ``registry`` was None, so
    # passing a registry without a lock raised AttributeError on
    # ``self._manager``. The attribute is still left unset when both are given.
    if registry is None or multiprocess_lock is None:
        self._manager = Manager()

    # initialize registry
    self._registry = self._manager.dict() if registry is None else registry

    # initialize multiprocess lock
    # NOTE(review): a lock created here is only shared with processes that
    # receive this same lock object; callers passing a shared registry
    # should normally pass the matching lock too.
    self._process_lock = self._manager.Lock() if multiprocess_lock is None else multiprocess_lock

    # initialize thread lock for operations within a single process
    self._thread_lock = threading.Lock()

    self.logger = RFLogger().create_logger(name)
135
+
136
+ # shared memory operations
137
+ def _safe_tensor_to_shared_memory(self, tensor: torch.Tensor | None) -> torch.Tensor | None:
138
+ """Safely convert a tensor to shared memory format"""
139
+ if tensor is None:
140
+ return None
141
+ tensor = tensor.cpu()
142
+ tensor = tensor.detach().contiguous().clone()
143
+ tensor.share_memory_()
144
+
145
+ return tensor
146
+
147
+ def _move_tensors_to_shared_memory(self, obj):
148
+ """Recursively move all tensors in a nested structure to shared memory"""
149
+ if isinstance(obj, torch.Tensor):
150
+ obj.share_memory_()
151
+ return obj
152
+ elif isinstance(obj, dict):
153
+ return {k: self._move_tensors_to_shared_memory(v) for k, v in obj.items()}
154
+ elif isinstance(obj, (list, tuple)):
155
+ return type(obj)(self._move_tensors_to_shared_memory(item) for item in obj)
156
+ else:
157
+ return obj
158
+
159
+ def _move_model_to_shared_memory(self, model):
160
+ """Move model to shared memory with proper BitsAndBytes handling"""
161
+ model = model.cpu()
162
+ for _, param in model.named_parameters():
163
+ if param.data is not None:
164
+ param.data = self._safe_tensor_to_shared_memory(param.data)
165
+
166
+ for name, buffer in model.named_buffers():
167
+ if isinstance(buffer, torch.Tensor) and buffer is not None:
168
+ parent_module = model
169
+ attr_path = name.split(".")
170
+
171
+ for attr in attr_path[:-1]:
172
+ parent_module = getattr(parent_module, attr)
173
+
174
+ shared_buffer = self._safe_tensor_to_shared_memory(buffer)
175
+ setattr(parent_module, attr_path[-1], shared_buffer)
176
+
177
+ bnb_modules = {}
178
+
179
+ for name, module in model.named_modules():
180
+ if not hasattr(module, "weight"):
181
+ continue
182
+
183
+ import bitsandbytes as bnb
184
+
185
+ bnb_layer_types = [bnb.nn.Linear4bit, bnb.nn.LinearFP4, bnb.nn.LinearNF4, bnb.nn.Params4bit]
186
+
187
+ is_bnb_layer = any(isinstance(module, layer_type) for layer_type in bnb_layer_types)
188
+
189
+ if is_bnb_layer and hasattr(module, "weight"):
190
+ bnb_attrs = {}
191
+ weight = module.weight
192
+
193
+ if hasattr(weight, "data") and weight.data is not None:
194
+ weight.data = self._safe_tensor_to_shared_memory(weight.data)
195
+
196
+ if hasattr(weight, "quant_state") and weight.quant_state is not None:
197
+ quant_state = weight.quant_state
198
+ bnb_attrs["quant_state_data"] = {}
199
+
200
+ for attr_name in dir(quant_state):
201
+ if not attr_name.startswith("_") and hasattr(quant_state, attr_name):
202
+ attr_val = getattr(quant_state, attr_name)
203
+
204
+ if isinstance(attr_val, torch.Tensor):
205
+ bnb_attrs["quant_state_data"][attr_name] = self._safe_tensor_to_shared_memory(attr_val)
206
+ elif not callable(attr_val):
207
+ bnb_attrs["quant_state_data"][attr_name] = attr_val
208
+
209
+ if hasattr(quant_state, "state2") and quant_state.state2 is not None:
210
+ state2 = quant_state.state2
211
+ bnb_attrs["state2_data"] = {}
212
+ for attr_name in dir(state2):
213
+ if not attr_name.startswith("_") and hasattr(state2, attr_name):
214
+ attr_val = getattr(state2, attr_name)
215
+ if isinstance(attr_val, torch.Tensor):
216
+ bnb_attrs["state2_data"][attr_name] = self._safe_tensor_to_shared_memory(attr_val)
217
+ elif not callable(attr_val):
218
+ bnb_attrs["state2_data"][attr_name] = attr_val
219
+
220
+ bnb_attrs["quant_state_class"] = type(quant_state).__name__
221
+
222
+ weight_attrs = ["compress_statistics", "quant_type", "blocksize", "bnb_quantized"]
223
+ for attr in weight_attrs:
224
+ if hasattr(weight, attr):
225
+ attr_val = getattr(weight, attr)
226
+ if isinstance(attr_val, torch.Tensor):
227
+ attr_val = self._safe_tensor_to_shared_memory(attr_val)
228
+ bnb_attrs[attr] = attr_val
229
+
230
+ bnb_attrs["weight_class"] = type(weight).__name__
231
+ bnb_modules[name] = bnb_attrs
232
+
233
+ return model, bnb_modules
234
+
235
+ # model object operations
236
def _save_full_model(self, model_id: str, model_data: dict, model_object_type: SHMObjectType):
    """Store a complete model and its tokenizer in the shared-memory registry.

    Args:
        model_id: Registry key; either a run id or the name of a base model.
        model_data: Mapping that holds the model under ``model_object_type``
            and the tokenizer under ``"tokenizer"``.
        model_object_type: Which kind of full model is being stored.

    The call is a no-op (apart from a debug log) when ``model_id`` is already
    registered.
    """
    lock = self._process_lock or self._thread_lock
    with lock:
        already_registered = model_id in self._registry
        if already_registered:
            self.logger.debug(f"Model {model_id} already exists in shared memory. Skipping save.")
            return

        # Check shared-memory capacity before touching the registry.
        # NOTE(review): presumably raises when space is insufficient - confirm.
        _verify_sufficient_model_size(model_data[model_object_type], self.logger)

        # Register an empty placeholder entry for this model.
        if model_id not in self._registry:
            self._registry[model_id] = {model_object_type: {}}

        # Pull both inputs out first, then transfer the model.
        cpu_model = model_data[model_object_type]
        tok = model_data["tokenizer"]
        shared, bnb_state = self._move_model_to_shared_memory(cpu_model)
        payload = {
            model_object_type: shared,
            "tokenizer": tok,
            "bnb_modules": self._move_tensors_to_shared_memory(bnb_state),
        }

        # Re-assign the whole entry rather than mutating it in place.
        entry = dict(self._registry[model_id])
        entry[model_object_type] = payload
        self._registry[model_id] = entry

        self.logger.debug(f"Saved {model_object_type.value} for run {model_id}")
264
+
265
def _save_ref_state_dict(self, model_id: str, ref_state_dict: dict):
    """Store the reference state dict for ``model_id`` in the shared registry.

    Args:
        model_id: Registry key of the model the state dict belongs to.
        ref_state_dict: State dict whose tensors are moved to shared memory.
    """
    ref_key = SHMObjectType.REF_STATE_DICT
    with self._thread_lock:
        # Check shared-memory capacity first.
        # NOTE(review): presumably raises when space is insufficient - confirm.
        _verify_sufficient_ref_state_dict_size(ref_state_dict, self.logger)

        # Make sure the model has a registry entry to attach to.
        if model_id not in self._registry:
            self._registry[model_id] = {ref_key: {}}

        shared_copy = self._move_tensors_to_shared_memory(ref_state_dict)

        # Build a fresh entry and re-assign it as a whole.
        self._registry[model_id] = {**self._registry[model_id], ref_key: shared_copy}

        self.logger.debug(f"Saved ref_state_dict for {model_id}")
282
+
283
def _update_checkpoints(self, model_id: str, checkpoint_updates: dict):
    """Merge ``checkpoint_updates`` into the checkpoints stored for ``model_id``.

    Tensors that already exist with the same shape and dtype and are backed by
    shared memory are overwritten in place. Nested dicts are merged
    recursively. Every other value replaces (or adds) the key; tensors and
    dicts stored this way are first moved to shared memory, including when an
    existing key changes type (e.g. ``None`` -> tensor, scalar -> dict).

    Args:
        model_id: Registry key of the model whose checkpoints are updated. A
            registry entry is created when none exists.
        checkpoint_updates: Possibly nested mapping of new checkpoint values.
    """

    def to_shared(value):
        """Return ``value`` in shared-memory form (tensors/dicts) or unchanged."""
        if isinstance(value, torch.Tensor):
            shared = value.cpu().clone()
            shared.share_memory_()
            return shared
        if isinstance(value, dict):
            return self._move_tensors_to_shared_memory(value)
        return value

    with self._thread_lock:
        # Create the model entry in the registry if needed.
        if model_id not in self._registry:
            self._registry[model_id] = {SHMObjectType.CHECKPOINTS: {}}

        model_entry = self._registry[model_id]
        if SHMObjectType.CHECKPOINTS not in model_entry:
            model_entry[SHMObjectType.CHECKPOINTS] = {}
        current_checkpoints = model_entry[SHMObjectType.CHECKPOINTS]

        updates_made = {"in_place": 0, "new_keys": 0}

        def update_nested_dict(current_dict, updates_dict, path=""):
            for key, new_value in updates_dict.items():
                current_path = f"{path}.{key}" if path else key
                replaced_tensor = False

                if key in current_dict:
                    current_value = current_dict[key]
                    both_tensors = isinstance(new_value, torch.Tensor) and isinstance(
                        current_value, torch.Tensor
                    )

                    if (
                        both_tensors
                        and current_value.shape == new_value.shape
                        and current_value.dtype == new_value.dtype
                        and current_value.is_shared()
                    ):
                        # Compatible tensor already in shared memory: reuse its storage.
                        current_value.copy_(new_value.cpu())
                        updates_made["in_place"] += 1
                        continue

                    if isinstance(new_value, dict) and isinstance(current_value, dict):
                        # Merge nested dicts key by key.
                        update_nested_dict(current_value, new_value, current_path)
                        continue

                    replaced_tensor = both_tensors

                # New key, or an existing key whose value cannot be updated in
                # place: store the shared-memory form of the new value.
                current_dict[key] = to_shared(new_value)
                if isinstance(new_value, (torch.Tensor, dict)):
                    updates_made["new_keys"] += 1
                if replaced_tensor:
                    self.logger.debug(f"New tensor (shape/type change): {current_path}")

        # Update the checkpoints
        update_nested_dict(current_checkpoints, checkpoint_updates)

        # Re-assign the registry entry to ensure Manager sees changes
        updated_entry = dict(model_entry)
        updated_entry[SHMObjectType.CHECKPOINTS] = current_checkpoints
        self._registry[model_id] = updated_entry

        self.logger.debug(f"Checkpoint update:{updates_made['in_place']} in-place, {updates_made['new_keys']} new")
353
+
354
def get_shm_objects(self) -> tuple[dict, Lock]:
    """Return the shared registry together with the process lock, in that order."""
    registry = self._registry
    process_lock = self._process_lock
    return registry, process_lock
357
+
358
def load_model_object(self, model_id: str, model_object_type: SHMObjectType):
    """Look up one stored object of ``model_id`` in the shared registry.

    Returns:
        The object stored under ``model_object_type``, or ``None`` when the
        model is not registered (a warning is logged) or has no object of
        that type.
    """
    entry = self._registry.get(model_id)
    if entry is not None:
        return entry.get(model_object_type)

    self.logger.warning(f"Model {model_id} not found in shared memory")
    return None
366
+
367
def save_model_object(self, model_id: str, model_object_type: SHMObjectType, model_object: dict):
    """Dispatch ``model_object`` to the saver that handles ``model_object_type``.

    Full-model types go to ``_save_full_model``, the reference state dict to
    ``_save_ref_state_dict`` and checkpoints to ``_update_checkpoints``. Any
    other type is ignored.
    """
    full_model_types = (
        SHMObjectType.BASE_MODEL,
        SHMObjectType.FULL_MODEL,
        SHMObjectType.REF_FULL_MODEL,
    )
    if model_object_type in full_model_types:
        self._save_full_model(model_id, model_object, model_object_type)
        return
    if model_object_type == SHMObjectType.REF_STATE_DICT:
        self._save_ref_state_dict(model_id, model_object)
        return
    if model_object_type == SHMObjectType.CHECKPOINTS:
        self._update_checkpoints(model_id, model_object)
376
+
377
def delete_model_object(self, model_id: str, base_model_name: str | None = None):
    """Remove ``model_id`` (and optionally its base model) from the shared registry.

    Args:
        model_id: Registry key of the model to delete. A warning is logged and
            nothing else happens when it is not registered.
        base_model_name: When given and registered, that whole registry entry
            is deleted as well (not just its ``SHMObjectType.BASE_MODEL`` key).

    After the registry is updated, garbage collection is forced and the CUDA
    cache is emptied when CUDA is available.
    """
    # Per-model objects dropped from the entry, in order, with their log labels.
    # TODO: add code to save checkpoints and full_model to disk before deleting
    removable = (
        (SHMObjectType.CHECKPOINTS, "checkpoints"),
        (SHMObjectType.FULL_MODEL, "full_model"),
        (SHMObjectType.REF_STATE_DICT, "ref_state_dict"),
        (SHMObjectType.REF_FULL_MODEL, "ref_full_model"),
    )

    with self._process_lock if self._process_lock else self._thread_lock:
        if model_id not in self._registry:
            self.logger.warning(f"Model '{model_id}' not found in shared memory during delete")
            return

        # Read the entry once instead of re-fetching it for every check.
        model_entry = self._registry[model_id]
        for object_type, label in removable:
            if model_entry.get(object_type):
                del model_entry[object_type]
                self.logger.debug(f"Deleted {label} for model {model_id} from shared memory")
        # Drop the local reference so it cannot keep objects alive through gc.collect().
        del model_entry

        # remove shared objects (entire registry entry is deleted for base_model)
        if base_model_name and base_model_name in self._registry:
            del self._registry[base_model_name]
            self.logger.debug(f"Deleted base_model for model {model_id} from shared memory")

        # pop() instead of del: the entry is already gone when
        # base_model_name == model_id, and that must not raise KeyError.
        self._registry.pop(model_id, None)
        self.logger.debug(f"Deleted model registry entry for {model_id} from shared memory")

        # Force garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.debug("Force garbage collection and empty cache")
433
+
434
def create_warm_start_checkpoint(self, model_id: str, warm_started_from: str):
    """Copy the warm start state from ``warm_started_from`` into ``model_id``.

    Deep copies of the source run's full model, reference state dict and
    checkpoints are stored under ``model_id``; a registry entry for
    ``model_id`` is created when none exists.

    Args:
        model_id: Registry key of the run that receives the copies.
        warm_started_from: Registry key of the run to copy from.

    Raises:
        KeyError: If ``warm_started_from`` is not registered, or its entry
            lacks one of the three copied object types.
    """
    copied_types = (
        SHMObjectType.FULL_MODEL,
        SHMObjectType.REF_STATE_DICT,
        SHMObjectType.CHECKPOINTS,
    )

    with self._thread_lock:
        if warm_started_from not in self._registry:
            raise KeyError(f"Run '{warm_started_from}' not found in shared memory")

        # create model entry in registry
        if model_id not in self._registry:
            self._registry[model_id] = {object_type: {} for object_type in copied_types}

        # Read the source entry once so all three copies come from one snapshot.
        source_entry = dict(self._registry[warm_started_from])
        model_entry = dict(self._registry[model_id])
        for object_type in copied_types:
            model_entry[object_type] = copy.deepcopy(source_entry[object_type])

        # Re-assign the whole entry rather than mutating it in place.
        self._registry[model_id] = model_entry
        self.logger.debug(f"Copied warm start checkpoint from {warm_started_from} to {model_id}")
460
+
461
def list_models(self):
    """Return the ids of every model currently held in the shared registry."""
    lock = self._process_lock or self._thread_lock
    with lock:
        model_ids = self._registry.keys()
        return list(model_ids)
465
+
466
def model_exists(self, model_id: str):
    """Report whether ``model_id`` has an entry in the shared registry."""
    lock = self._process_lock or self._thread_lock
    with lock:
        found = model_id in self._registry
    return found
@@ -0,0 +1,23 @@
1
+ """This module contains the TrainerConfig class which is responsible for configuring the trainer."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Optional
5
+
6
+ import torch
7
+
8
+
9
@dataclass
class TrainerConfig:
    """Trainer configuration.

    Plain data container for the settings of one trainer run. Every field is
    required (none has a default), so instances must be built with all eleven
    values, positionally in declaration order or by keyword.
    """

    # Identifier of the worker - presumably the one executing this run; confirm against caller.
    worker_id: int
    # Identifier of the run being trained.
    run_id: int
    # NOTE(review): presumably the MLflow run that tracks this run - confirm.
    mlflow_run_id: str
    # String-keyed configuration mapping; its schema is not visible here.
    config_leaf: dict[str, Any]
    # NOTE(review): presumably the planned number of training steps - confirm.
    total_steps: int
    # NOTE(review): presumably the steps already finished (e.g. when resuming) - confirm.
    completed_steps: int
    # Callable used to build the model; its signature is not constrained here.
    create_model_fn: Callable
    # Dataset used for training.
    train_dataset: torch.utils.data.Dataset
    # Dataset used for evaluation; may be None.
    eval_dataset: Optional[torch.utils.data.Dataset]
    # NOTE(review): presumably the run id this run was warm-started from, None otherwise - confirm.
    warm_started_from: int | None
    # NOTE(review): presumably the count of fully completed epochs - confirm.
    num_epochs_completed: int