rapidfireai 0.0.1__py3-none-any.whl → 0.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidfireai might be problematic. Click here for more details.

Files changed (320)
  1. rapidfireai/__init__.py +11 -5
  2. rapidfireai/automl/__init__.py +20 -0
  3. rapidfireai/automl/base.py +48 -0
  4. rapidfireai/automl/datatypes.py +42 -0
  5. rapidfireai/automl/grid_search.py +125 -0
  6. rapidfireai/automl/model_config.py +102 -0
  7. rapidfireai/automl/random_search.py +145 -0
  8. rapidfireai/backend/__init__.py +0 -0
  9. rapidfireai/backend/chunks.py +63 -0
  10. rapidfireai/backend/controller.py +637 -0
  11. rapidfireai/backend/scheduler.py +137 -0
  12. rapidfireai/backend/worker.py +272 -0
  13. rapidfireai/cli.py +380 -0
  14. rapidfireai/db/__init__.py +0 -0
  15. rapidfireai/db/db_interface.py +135 -0
  16. rapidfireai/db/rf_db.py +694 -0
  17. rapidfireai/db/tables.sql +64 -0
  18. rapidfireai/dispatcher/dispatcher.py +391 -0
  19. rapidfireai/dispatcher/gunicorn.conf.py +25 -0
  20. rapidfireai/experiment.py +168 -0
  21. rapidfireai/frontend/build/asset-manifest.json +276 -0
  22. rapidfireai/frontend/build/favicon.ico +0 -0
  23. rapidfireai/frontend/build/index.html +1 -0
  24. rapidfireai/frontend/build/manifest.json +15 -0
  25. rapidfireai/frontend/build/pdf.worker.js +1 -0
  26. rapidfireai/frontend/build/report.html +39 -0
  27. rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
  28. rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
  29. rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
  30. rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
  31. rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
  32. rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
  33. rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
  34. rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
  35. rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
  36. rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
  37. rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
  38. rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
  39. rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
  40. rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
  41. rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
  42. rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
  43. rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
  44. rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
  45. rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
  46. rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
  47. rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
  48. rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
  49. rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
  50. rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
  51. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
  52. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
  53. rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
  54. rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
  55. rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
  56. rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
  57. rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
  58. rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
  59. rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
  60. rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
  61. rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
  62. rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
  63. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
  64. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
  65. rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
  66. rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
  67. rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
  68. rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
  69. rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
  70. rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
  71. rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
  72. rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
  73. rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
  74. rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
  75. rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
  76. rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
  77. rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
  78. rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
  79. rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
  80. rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
  81. rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
  82. rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
  83. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
  84. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
  85. rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
  86. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
  87. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
  88. rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
  89. rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
  90. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
  91. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
  92. rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
  93. rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
  94. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
  95. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
  96. rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
  97. rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
  98. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
  99. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
  100. rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
  101. rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
  102. rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
  103. rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
  104. rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
  105. rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
  106. rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
  107. rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
  108. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
  109. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
  110. rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
  111. rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
  112. rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
  113. rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
  114. rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
  115. rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
  116. rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
  117. rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
  118. rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
  119. rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
  120. rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
  121. rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
  122. rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
  123. rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
  124. rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
  125. rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
  126. rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
  127. rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
  128. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
  129. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
  130. rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
  131. rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
  132. rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
  133. rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
  134. rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
  135. rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
  136. rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
  137. rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
  138. rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
  139. rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
  140. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
  141. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
  142. rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
  143. rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
  144. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
  145. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
  146. rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
  147. rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
  148. rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
  149. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
  150. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
  151. rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
  152. rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
  153. rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
  154. rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
  155. rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
  156. rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
  157. rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
  158. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
  159. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
  160. rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
  161. rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
  162. rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
  163. rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
  164. rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
  165. rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
  166. rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
  167. rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
  168. rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
  169. rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
  170. rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
  171. rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
  172. rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
  173. rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
  174. rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
  175. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
  176. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
  177. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
  178. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
  179. rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
  180. rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
  181. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
  182. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
  183. rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
  184. rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
  185. rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
  186. rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
  187. rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
  188. rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
  189. rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
  190. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
  191. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
  192. rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
  193. rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
  194. rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
  195. rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
  196. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
  197. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
  198. rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
  199. rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
  200. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
  201. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
  202. rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
  203. rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
  204. rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
  205. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
  206. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
  207. rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
  208. rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
  209. rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
  210. rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
  211. rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
  212. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
  213. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
  214. rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
  215. rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
  216. rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
  217. rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
  218. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
  219. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
  220. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
  221. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
  222. rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
  223. rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
  224. rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
  225. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
  226. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
  227. rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
  228. rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
  229. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
  230. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
  231. rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
  232. rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
  233. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
  234. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
  235. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
  236. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
  237. rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
  238. rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
  239. rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
  240. rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
  241. rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
  242. rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
  243. rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
  244. rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
  245. rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
  246. rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
  247. rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
  248. rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
  249. rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
  250. rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
  251. rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
  252. rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
  253. rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
  254. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
  255. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
  256. rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
  257. rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
  258. rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
  259. rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
  260. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
  261. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
  262. rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
  263. rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
  264. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
  265. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
  266. rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
  267. rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
  268. rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
  269. rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
  270. rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
  271. rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
  272. rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
  273. rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
  274. rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
  275. rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
  276. rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
  277. rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
  278. rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
  279. rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
  280. rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
  281. rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
  282. rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
  283. rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
  284. rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
  285. rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
  286. rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
  287. rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
  288. rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
  289. rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
  290. rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
  291. rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
  292. rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
  293. rapidfireai/frontend/proxy_middleware.py +233 -0
  294. rapidfireai/frontend/server.py +25 -0
  295. rapidfireai/ml/__init__.py +0 -0
  296. rapidfireai/ml/callbacks.py +176 -0
  297. rapidfireai/ml/checkpoint_utils.py +540 -0
  298. rapidfireai/ml/trainer.py +309 -0
  299. rapidfireai/start.sh +634 -0
  300. rapidfireai/utils/__init__.py +0 -0
  301. rapidfireai/utils/automl_utils.py +51 -0
  302. rapidfireai/utils/constants.py +141 -0
  303. rapidfireai/utils/datapaths.py +69 -0
  304. rapidfireai/utils/exceptions.py +82 -0
  305. rapidfireai/utils/experiment_utils.py +370 -0
  306. rapidfireai/utils/logging.py +87 -0
  307. rapidfireai/utils/mlflow_manager.py +121 -0
  308. rapidfireai/utils/serialize.py +15 -0
  309. rapidfireai/utils/shm_manager.py +469 -0
  310. rapidfireai/utils/trainer_config.py +23 -0
  311. rapidfireai/utils/worker_manager.py +219 -0
  312. rapidfireai/version.py +6 -0
  313. rapidfireai-0.9.10.dist-info/METADATA +247 -0
  314. rapidfireai-0.9.10.dist-info/RECORD +318 -0
  315. rapidfireai-0.9.10.dist-info/entry_points.txt +2 -0
  316. rapidfireai-0.0.1.dist-info/METADATA +0 -37
  317. rapidfireai-0.0.1.dist-info/RECORD +0 -6
  318. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/WHEEL +0 -0
  319. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/licenses/LICENSE +0 -0
  320. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,637 @@
1
+ """This module contains the Controller class which is responsible for orchestrating the RapidFire lifecycle."""
2
+
3
+ import math
4
+ import random
5
+ import time
6
+ from logging import Logger
7
+ from pathlib import Path
8
+ from pprint import pformat
9
+ from typing import Any, Callable
10
+
11
+ import mlflow
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+
15
+ from rapidfireai.automl import AutoMLAlgorithm
16
+ from rapidfireai.backend.scheduler import Scheduler
17
+ from rapidfireai.db.rf_db import RfDb
18
+ from rapidfireai.utils.automl_utils import get_flattened_config_leaf, get_runs
19
+ from rapidfireai.utils.constants import (
20
+ MLFLOW_URL,
21
+ ControllerTask,
22
+ ExperimentTask,
23
+ RunEndedBy,
24
+ RunSource,
25
+ RunStatus,
26
+ TaskStatus,
27
+ WorkerTask,
28
+ )
29
+ from rapidfireai.utils.datapaths import DataPath
30
+ from rapidfireai.utils.exceptions import ControllerException, NoGPUsFoundException
31
+ from rapidfireai.utils.logging import RFLogger
32
+ from rapidfireai.utils.mlflow_manager import MLflowManager
33
+ from rapidfireai.utils.serialize import encode_payload
34
+ from rapidfireai.utils.shm_manager import SharedMemoryManager
35
+ from rapidfireai.utils.worker_manager import WorkerManager
36
+
37
+
38
+ class Controller:
39
+ """This module contains the ML Controller class which is responsible for orchestrating the RapidFire lifecycle."""
40
+
41
def __init__(self, experiment_id: int, experiment_name: str) -> None:
    """Set up the controller and the services it coordinates.

    Builds, in order: the database handle, the controller/user/interactive-control
    loggers, the worker count (one per visible CUDA device), the shared-memory
    manager, the worker manager and the MLflow manager.

    Args:
        experiment_id: Identifier of the experiment, stored on the instance.
        experiment_name: Name of the experiment; also passed to
            ``MLflowManager.get_experiment``.

    Raises:
        NoGPUsFoundException: If ``torch.cuda.device_count()`` reports zero devices.
    """
    # Imported locally, as in the original implementation.
    import torch.multiprocessing as mp

    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        # Start method already set
        pass

    self.experiment_id: int = experiment_id
    self.experiment_name: str = experiment_name

    # Database access object.
    self.db: RfDb = RfDb()

    # One factory, three named loggers.
    logger_factory = RFLogger()
    self.logger: Logger = logger_factory.create_logger("controller")
    self.user_logger: Logger = logger_factory.create_logger("user")
    self.ic_logger: Logger = logger_factory.create_logger("interactive-control")

    # One worker per CUDA device; refuse to start without any.
    self.num_workers: int = torch.cuda.device_count()
    if not self.num_workers:
        raise NoGPUsFoundException("No GPUs found while initializing controller.")
    self.logger.debug(f"Found {self.num_workers} workers/GPUs.")

    # Shared-memory manager plus the registry/lock pair it exposes.
    self.shm_manager: SharedMemoryManager = SharedMemoryManager(name="controller-shm")
    shm_registry, shm_lock = self.shm_manager.get_shm_objects()

    # Worker manager receives the same registry and lock.
    self.worker_manager: WorkerManager = WorkerManager(self.num_workers, shm_registry, shm_lock)

    # MLflow manager bound to the configured tracking URL.
    self.mlflow_manager: MLflowManager = MLflowManager(MLFLOW_URL)
    self.mlflow_manager.get_experiment(self.experiment_name)

    self.logger.debug("Controller initialized")
81
+
82
def _create_models(
    self,
    param_config: AutoMLAlgorithm | dict[str, Any],
    source: RunSource,
    seed: int,
    len_train_dataset: int,
    num_chunks: int,
    warm_start_info: dict[str, Any] | None = None,
) -> list[int]:
    """Create a run (DB row, directories, MLflow run) for every config leaf.

    Config leaves are obtained from ``get_runs(param_config, seed)``. For each
    leaf a run is inserted into the database with status ``RunStatus.NEW``,
    its data directories are created, and an MLflow run is created and
    populated with the flattened config. Finally the experiment's current
    task is set to ``ExperimentTask.RUN_FIT``.

    Args:
        param_config: AutoML algorithm or a plain config dict to expand into runs.
        source: Origin of the runs; stored with each run and used in log output.
        seed: Seed forwarded to ``get_runs``.
        len_train_dataset: Forwarded to ``self._get_total_step``.
        num_chunks: Forwarded to ``self._get_total_step``.
        warm_start_info: Optional dict with ``"start_chunk_id"`` and
            ``"parent_run_id"``; when falsy the run starts at chunk 0 with no parent.

    Returns:
        The database ids of the created runs, in creation order.

    Raises:
        ControllerException: If a run directory cannot be created.
    """
    config_leafs = get_runs(param_config, seed)

    runs: dict[int, dict[str, Any]] = {}
    for config_leaf in config_leafs:
        flattened_config = get_flattened_config_leaf(config_leaf)
        total_steps = self._get_total_step(config_leaf, len_train_dataset, num_chunks)

        run_id = self.db.create_run(
            config_leaf=config_leaf,
            status=RunStatus.NEW,
            completed_steps=0,
            total_steps=total_steps,
            error="",
            source=source,
            ended_by=None,
            start_chunk_id=warm_start_info["start_chunk_id"] if warm_start_info else 0,
            warm_started_from=warm_start_info["parent_run_id"] if warm_start_info else None,
        )
        runs[run_id] = flattened_config

        # Create the on-disk layout for this run.
        # PermissionError is a subclass of OSError, so catching OSError covers both.
        try:
            base_run_path = DataPath.base_run_path(run_id)
            work_dir_path = DataPath.work_dir_path(base_run_path)
            initial_checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
            final_checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
            intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path)

            for run_dir in (
                work_dir_path,
                initial_checkpoint_path,
                final_checkpoint_path,
                intermediate_checkpoint_path,
            ):
                Path.mkdir(run_dir, parents=True, exist_ok=True)
        except OSError as e:
            raise ControllerException(f"Failed to create required Run DataPath directories: {e}") from e

        # Reset per iteration so the error handler can tell whether an MLflow
        # run was actually created for *this* run. Without this, a failure in
        # create_run() would hit an unbound name on the first iteration or end
        # the previous iteration's MLflow run on later ones.
        mlflow_run_id = None
        try:
            mlflow_run_id = self.mlflow_manager.create_run(str(run_id))

            # Populate MLflow with the model config.
            for key, value in flattened_config.items():
                self.mlflow_manager.log_param(mlflow_run_id, key, value)
            if warm_start_info:
                self.mlflow_manager.log_param(mlflow_run_id, "warm-start", str(warm_start_info))
            self.logger.debug(f"Populated MLFlow with model config info for run {run_id}.")
            self.db.set_run_details(
                run_id=run_id,
                mlflow_run_id=mlflow_run_id,
                flattened_config=flattened_config,
            )
        except mlflow.exceptions.MlflowException as e:
            # Best effort: report and continue with the remaining config leaves.
            msg = f"Error creating new MLFlow run for run {run_id} - {e}."
            print(msg)
            if mlflow_run_id is not None:
                self.mlflow_manager.end_run(mlflow_run_id)
            self.logger.error(msg, exc_info=True)

    total_runs = len(runs)
    self.logger.info(f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}")
    self.logger.debug(f"Got {total_runs} runs for {source.value}.")

    # Hand over to the fit phase.
    self.db.set_experiment_current_task(ExperimentTask.RUN_FIT)
    self.logger.debug("Completed creating models.")

    return list(runs.keys())
162
+
163
def _clear_run_from_shm(self, run_id: int) -> None:
    """Remove a run's model object from shared memory.

    The run's base model name is passed to
    ``shm_manager.delete_model_object`` only when no other run in status
    ONGOING, NEW or STOPPED has the same ``model_name``; otherwise ``None``
    is passed in its place.

    Args:
        run_id: Database id of the run being cleared.
    """
    model_name = self.db.get_run(run_id)["config_leaf"]["model_name"]
    candidate_runs = self.db.get_runs_by_status([RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED])

    # True as soon as some *other* candidate run uses the same base model.
    shared_with_other_run = any(
        details["config_leaf"]["model_name"] == model_name and other_id != run_id
        for other_id, details in candidate_runs.items()
    )

    self.shm_manager.delete_model_object(run_id, None if shared_with_other_run else model_name)
179
+
180
+ def _process_interactive_control(
181
+ self,
182
+ run_states: dict[str, Any],
183
+ clone_modify_tasks: list[dict[str, Any]],
184
+ len_train_dataset: int,
185
+ seed: int,
186
+ num_chunks: int,
187
+ ) -> None:
188
+ """Process interactive control tasks."""
189
+
190
+ # process non-clone_modify tasks
191
+ for run_id, run_state in run_states.items():
192
+ if not run_state["task_id"]:
193
+ continue
194
+
195
+ if run_state["status"] == RunStatus.STOPPED:
196
+ # process stopped tasks
197
+ # mark run as stopped
198
+ self.db.set_run_details(
199
+ run_id=run_id,
200
+ status=RunStatus.STOPPED,
201
+ ended_by=RunEndedBy.INTERACTIVE_CONTROL,
202
+ )
203
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
204
+ self.ic_logger.info(f"Stopping run {run_id} by Interactive Control")
205
+ elif run_state["status"] == RunStatus.DELETED:
206
+ # process deleted tasks
207
+ # clear run from shm
208
+ self._clear_run_from_shm(run_id)
209
+ # delete run from MLFlow
210
+ mlflow_run_id = self.db.get_run(run_id)["mlflow_run_id"]
211
+ self.mlflow_manager.delete_run(mlflow_run_id)
212
+ # mark run as deleted
213
+ self.db.set_run_details(
214
+ run_id=run_id,
215
+ status=RunStatus.DELETED,
216
+ ended_by=RunEndedBy.INTERACTIVE_CONTROL,
217
+ )
218
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
219
+ self.ic_logger.info(f"Deleting run {run_id} by Interactive Control")
220
+ elif run_state["status"] == RunStatus.ONGOING:
221
+ # process ongoing tasks
222
+ self.db.set_run_details(
223
+ run_id=run_id,
224
+ status=RunStatus.ONGOING,
225
+ ended_by="",
226
+ )
227
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
228
+ self.ic_logger.info(f"Resuming run {run_id} by Interactive Control")
229
+ elif run_state["status"] == RunStatus.COMPLETED:
230
+ # process completed tasks
231
+ self.logger.warning(f"Run {run_id} is already completed. Skipping Interactive Control task.")
232
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.SKIPPED)
233
+ else:
234
+ raise ValueError(f"Unsupported run status {run_state['status']}")
235
+
236
+ # process clone_modify tasks from the collected list
237
+ for task in clone_modify_tasks:
238
+ parent_run_id, ic_op, config_leaf = (
239
+ task["run_id"],
240
+ task["ic_op"],
241
+ task["config_leaf"],
242
+ )
243
+
244
+ # add additional_kwargs to config_leaf if it exists in the parent run
245
+ parent_run_details = self.db.get_run(parent_run_id)
246
+ if "additional_kwargs" in parent_run_details["config_leaf"]:
247
+ config_leaf["additional_kwargs"] = parent_run_details["config_leaf"]["additional_kwargs"]
248
+
249
+ # create model for the new run
250
+ try:
251
+ if ic_op == ControllerTask.IC_CLONE_MODIFY:
252
+ run_ids = self._create_models(
253
+ config_leaf,
254
+ RunSource.INTERACTIVE_CONTROL,
255
+ seed,
256
+ len_train_dataset,
257
+ num_chunks=num_chunks,
258
+ )
259
+ elif ic_op == ControllerTask.IC_CLONE_MODIFY_WARM:
260
+ warm_start_info = {
261
+ "parent_run_id": parent_run_id,
262
+ "start_chunk_id": parent_run_details["num_chunks_visited_curr_epoch"],
263
+ }
264
+ run_ids = self._create_models(
265
+ config_leaf,
266
+ RunSource.INTERACTIVE_CONTROL,
267
+ seed,
268
+ len_train_dataset,
269
+ num_chunks,
270
+ warm_start_info,
271
+ )
272
+ else:
273
+ raise ValueError(f"Unsupported IC operation {ic_op}")
274
+
275
+ # mark task as completed
276
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.COMPLETED)
277
+ self.ic_logger.info(
278
+ f"Cloned run {parent_run_id} by Interactive Control with {ic_op.value} into runs - {run_ids}"
279
+ )
280
+ except Exception as e:
281
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
282
+ self.ic_logger.error(f"Error creating model for run {parent_run_id}: {e}")
283
+ raise ControllerException(f"Error creating model for run {parent_run_id}: {e}") from e
284
+
285
+ def _process_interm_ic_ops_states(
286
+ self,
287
+ currently_scheduled_runs: list[int],
288
+ ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
289
+ """Process the interactive control."""
290
+ # get IC Ops scheduled tasks
291
+ ic_scheduled_tasks = self.db.get_scheduled_ic_ops_tasks()
292
+
293
+ # track states for each task(run) and collect clone_modify tasks separately
294
+ run_states = {}
295
+ clone_modify_tasks = []
296
+ for task in ic_scheduled_tasks:
297
+ run_id = task["run_id"]
298
+
299
+ # skip if run is currently scheduled (we process IC ops only at chunk boundaries)
300
+ if run_id in currently_scheduled_runs:
301
+ # self.logger.debug(f"Skipping IC op for run {run_id} as it is currently scheduled")
302
+ continue
303
+
304
+ is_clone_modify_task = task["ic_op"] in (
305
+ ControllerTask.IC_CLONE_MODIFY,
306
+ ControllerTask.IC_CLONE_MODIFY_WARM,
307
+ )
308
+
309
+ if is_clone_modify_task:
310
+ # clone_modify tasks
311
+ # get latest run state
312
+ run_status = run_states[run_id]["status"] if run_id in run_states else self.db.get_run(run_id)["status"]
313
+
314
+ # track clone_modify tasks only for non-deleted runs
315
+ if run_status != RunStatus.DELETED:
316
+ clone_modify_tasks.append(task)
317
+ self.ic_logger.info(f"Added {task['ic_op']} task for run {run_id}.")
318
+ else:
319
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
320
+ self.ic_logger.warning(f"Skipping {task['ic_op']} task for deleted run {run_id}.")
321
+ else:
322
+ # Non clone_modify tasks
323
+ if run_id not in run_states:
324
+ run_states[run_id] = {
325
+ "task_id": None,
326
+ "task": None,
327
+ "status": self.db.get_run(run_id)["status"],
328
+ }
329
+
330
+ # update run states based on existing status and task
331
+ current_status = run_states[run_id]["status"]
332
+ if current_status == RunStatus.COMPLETED and task["ic_op"] in [
333
+ ControllerTask.IC_RESUME,
334
+ ControllerTask.IC_STOP,
335
+ ]:
336
+ # ignore RESUME/STOP tasks for completed runs
337
+ self.ic_logger.warning(f"Ignoring RESUME/STOP task for run {run_id} as it is already completed")
338
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
339
+ elif current_status == RunStatus.FAILED and task["ic_op"] != ControllerTask.IC_DELETE:
340
+ # ignore all tasks except DELETE for failed runs
341
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for failed run {run_id}")
342
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
343
+ elif current_status == RunStatus.DELETED:
344
+ # ignore all tasks for deleted runs
345
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for deleted run {run_id}")
346
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
347
+ else:
348
+ # valid ic_op for this run
349
+ # mark prev task as completed
350
+ if run_states[run_id]["task_id"] is not None:
351
+ self.db.set_ic_ops_task_status(run_states[run_id]["task_id"], TaskStatus.COMPLETED)
352
+
353
+ # add new task to run states
354
+ if task["ic_op"] == ControllerTask.IC_STOP:
355
+ updated_status = RunStatus.STOPPED
356
+ info_msg = f"Received STOP task for run {run_id}"
357
+ elif task["ic_op"] == ControllerTask.IC_DELETE:
358
+ updated_status = RunStatus.DELETED
359
+ info_msg = f"Received DELETE task for run {run_id}"
360
+ elif task["ic_op"] == ControllerTask.IC_RESUME:
361
+ updated_status = RunStatus.ONGOING
362
+ info_msg = f"Received RESUME task for run {run_id}"
363
+ else:
364
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
365
+ raise ValueError(f"Unsupported task {task['ic_op']}")
366
+ run_states[run_id].update(
367
+ {
368
+ "task_id": task["task_id"],
369
+ "task": task["ic_op"],
370
+ "status": updated_status if updated_status else current_status,
371
+ }
372
+ )
373
+ self.ic_logger.info(info_msg)
374
+
375
+ return run_states, clone_modify_tasks
376
+
377
+ def _get_total_step(self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int) -> int:
378
+ """Get the total number of steps for a run."""
379
+ num_train_epochs = config_leaf["training_args"].get("num_train_epochs", 1)
380
+
381
+ total_steps = 0
382
+ # max_steps overrides num_train_epochs
383
+ if config_leaf["training_args"].get("max_steps", None):
384
+ # ceil to nearest chunk multiple
385
+ total_steps = config_leaf["training_args"]["max_steps"]
386
+ elif num_train_epochs:
387
+ # total_steps = num_epochs to num_steps =
388
+ per_device_train_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1)
389
+ gradient_accumulation_steps = config_leaf["training_args"].get("gradient_accumulation_steps", 1)
390
+ total_steps = (
391
+ math.ceil(len_train_dataset / (num_chunks * per_device_train_batch_size * gradient_accumulation_steps))
392
+ * num_chunks
393
+ * num_train_epochs
394
+ )
395
+ return total_steps
396
+
397
def run_fit(
    self,
    param_config: Any,
    create_model_fn: Callable,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    num_chunks: int,
    seed: int = 42,
) -> None:
    """Run the fit.

    Saves the datasets for the workers, creates the initial runs from
    ``param_config``, starts the workers and then drives the scheduling loop:
    each pass processes freshly completed/failed worker tasks, applies
    Interactive Control operations, syncs the scheduler with the run statuses
    and schedules at most one new (run, worker, chunk) task. The loop ends when
    the scheduler reports that nothing is left, or when an experiment error is
    recorded in the database. Workers are shut down afterwards.

    Args:
        param_config: Configuration forwarded to ``_create_models``.
        create_model_fn: Callable passed to workers through the task's
            ``config_options``.
        train_dataset: Training dataset; its length is used for step accounting.
        eval_dataset: Evaluation dataset; stored as ``None`` when falsy.
        num_chunks: Number of chunks per epoch.
        seed: Seed for ``random`` and for model creation.

    Raises:
        ControllerException: If saving datasets, creating models, creating
            workers, processing IC tasks or the scheduling loop itself fails.
    """

    # set experiment task to create models
    self.db.set_experiment_current_task(ExperimentTask.CREATE_MODELS)
    self.logger.debug(f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}.")

    # save train and eval dataset objects to a file for workers to load
    try:
        datasets = {
            "train": train_dataset,
            "eval": eval_dataset if eval_dataset else None,
            "num_chunks": num_chunks,
        }
        with open(DataPath.dataset_path(), "w", encoding="utf-8") as f:
            f.write(encode_payload(datasets))
        self.logger.debug(f"Saved datasets to {DataPath.dataset_path()}")
    except Exception as e:
        raise ControllerException(f"Error saving datasets: {e}") from e

    # set seed
    random.seed(seed)
    self.logger.info(f"Set seed to {seed}")

    # create models
    try:
        len_train_dataset = len(train_dataset)
        self._create_models(param_config, RunSource.INITIAL, seed, len_train_dataset, num_chunks=num_chunks)
        self.logger.debug("Created models.")
    except Exception as e:
        raise ControllerException(f"Error creating models: {e}") from e

    # set experiment task to run fit
    self.db.set_experiment_current_task(ExperimentTask.RUN_FIT)
    self.logger.debug(f"Set experiment task to {ExperimentTask.RUN_FIT.value}.")

    # create workers
    try:
        self.worker_manager.create_workers()
        print("Created workers")
        self.logger.debug(f"Created {self.num_workers} workers.")
    except Exception as e:
        raise ControllerException(f"Error creating workers: {e}") from e

    # create scheduler, seeded with the runs that are still NEW
    run_ids = list(
        self.db.get_runs_by_status(
            [
                RunStatus.NEW,
            ]
        ).keys()
    )
    scheduler = Scheduler(run_ids, self.num_workers, num_chunks)

    # run fit
    self.logger.info("Starting Training and Validation")
    try:
        all_done = False
        prev_worker_tasks = {}  # Track previous iteration's worker tasks

        while not all_done:
            # check for errors
            exp_error = self.db.get_experiment_error()
            if exp_error:
                print(f"Error in experiment: {exp_error}")
                self.logger.error(f"Error in experiment: {exp_error}")
                break

            # get current state (pre IC ops states)
            all_worker_tasks = self.db.get_all_worker_tasks()
            all_run_details = self.db.get_all_runs()

            # Filter and separate fresh completed and failed tasks in a single loop
            completed_tasks = {}
            failed_tasks = []
            for worker_id, worker_task in all_worker_tasks.items():
                prev_task = prev_worker_tasks.get(worker_id, {})
                current_task_tuple = (worker_task["task_id"], worker_task["status"])
                prev_task_tuple = (prev_task.get("task_id"), prev_task.get("status"))

                # skip if task is the same as previous iteration (no change in status) or run is not active
                if current_task_tuple == prev_task_tuple or worker_task["run_id"] not in scheduler.run_ids:
                    continue

                if worker_task["status"] == TaskStatus.COMPLETED:
                    completed_tasks[worker_id] = worker_task
                elif worker_task["status"] == TaskStatus.FAILED:
                    failed_tasks.append(worker_task)

            # Update prev_worker_tasks for next iteration (only track task_id and status).
            # This is done right after classification, not at the bottom of the loop, so
            # that the "no schedule possible" `continue` below cannot skip it; otherwise a
            # task handled in this iteration would look fresh again on the next one and be
            # processed twice. The snapshot is not modified later in the iteration, so the
            # recorded value is the same as before.
            prev_worker_tasks = {
                worker_id: {"task_id": worker_task["task_id"], "status": worker_task["status"]}
                for worker_id, worker_task in all_worker_tasks.items()
            }

            # Process completed tasks first (before scheduling new ones)
            for worker_id, worker_task in completed_tasks.items():
                run_id = worker_task["run_id"]
                chunk_id = worker_task["chunk_id"]
                run_details = all_run_details[run_id]
                self.logger.debug(f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}")
                self.logger.info(
                    f"Run {run_id} completed steps - {run_details['completed_steps']}/{run_details['total_steps']}"
                )

                # Update scheduler state
                scheduler.set_completed_task(worker_id)

                # Update database state and local state using scheduler's state as source of truth
                new_chunks_visited = scheduler.run_visited_num_chunks[run_id]
                if new_chunks_visited == num_chunks:
                    num_epochs_completed = run_details["num_epochs_completed"] + 1
                else:
                    num_epochs_completed = run_details["num_epochs_completed"]
                self.db.set_run_details(
                    run_id=run_id,
                    num_chunks_visited_curr_epoch=new_chunks_visited,
                    num_epochs_completed=num_epochs_completed,
                )

                # Update progress
                progress_percentage = (
                    (run_details["completed_steps"] / run_details["total_steps"] * 100)
                    if run_details["total_steps"] > 0
                    else 0
                )
                self.db.set_controller_progress(run_id, progress_percentage)

                # Check if run has completed all epochs
                # completed_steps can go beyond total_steps since we stop only at a chunk boundary
                if run_details["completed_steps"] >= run_details["total_steps"]:
                    scheduler.remove_run(run_id)
                    self.db.set_run_details(
                        run_id=run_id,
                        status=RunStatus.COMPLETED,
                        ended_by=RunEndedBy.EPOCH_COMPLETED,
                    )
                    self.logger.info(
                        f"Run {run_id} has completed all its epochs - "
                        f"steps {run_details['completed_steps']}/{run_details['total_steps']}"
                    )
                # Check if run has completed only current epoch (hasn't reached total_steps yet)
                elif (
                    new_chunks_visited == num_chunks and run_details["completed_steps"] < run_details["total_steps"]
                ):
                    scheduler.reset_run(run_id)
                    self.db.set_run_details(run_id=run_id, num_chunks_visited_curr_epoch=0)
                    self.logger.info(f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)")

            # Check for failed runs and update scheduler, local state, shm
            for worker_task in failed_tasks:
                run_id = worker_task["run_id"]
                run_error = all_run_details[run_id]["error"]
                if run_id in scheduler.run_ids:
                    scheduler.remove_run(run_id)
                    self._clear_run_from_shm(run_id)
                    err_msg = f"Run {run_id} has failed: {run_error}"
                    print(err_msg)
                    self.logger.error(err_msg)
                    self.logger.debug(f"Removed run {run_id} from scheduler")

            # Process interactive control tasks (this fetches latest run states internally)
            try:
                currently_scheduled_runs = list(scheduler.worker_running_current_run.values())
                run_states, clone_modify_tasks = self._process_interm_ic_ops_states(currently_scheduled_runs)
                self._process_interactive_control(
                    run_states, clone_modify_tasks, len_train_dataset, seed, num_chunks
                )
            except Exception as e:
                raise ControllerException(f"Error processing interactive control tasks: {e}") from e

            # fetch latest run states again (post IC ops states)
            all_run_details = self.db.get_all_runs()

            # Update scheduler with active and inactive runs from IC Ops changes
            for run_id, run_details in all_run_details.items():
                # add active runs to scheduler
                if run_details["status"] in (RunStatus.ONGOING, RunStatus.NEW) and run_id not in scheduler.run_ids:
                    chunks_visited = all_run_details[run_id]["num_chunks_visited_curr_epoch"]
                    start_chunk_id = all_run_details[run_id]["start_chunk_id"]
                    scheduler.add_run(run_id, chunks_visited, start_chunk_id)
                    self.logger.debug(f"Added run {run_id} to scheduler with {chunks_visited} chunks visited")
                # remove inactive runs from scheduler
                elif (
                    run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED) and run_id in scheduler.run_ids
                ):
                    scheduler.remove_run(run_id)
                    self.logger.debug(f"Removed run {run_id} from scheduler")

            # Get schedule from scheduler
            schedule = scheduler.schedule()
            run_id = schedule["run_id"]
            worker_id = schedule["worker_id"]
            chunk_id = schedule["chunk_id"]

            # Check termination condition (all three fields None)
            if run_id is None and worker_id is None and chunk_id is None:
                self.logger.info("Scheduler indicates all runs have completed all chunks")
                all_done = True
                break

            # Check if no schedule possible (all three fields -1): wait and poll again
            if run_id == -1 and worker_id == -1 and chunk_id == -1:
                # self.logger.debug("No schedule possible - all workers busy or no available runs")
                time.sleep(1)
                continue

            # Execute Schedule
            # Create worker task
            # self.logger.debug(f"Scheduler schedule: {schedule}")
            self.db.set_run_details(run_id=run_id, status=RunStatus.ONGOING)
            self.db.create_worker_task(
                worker_id,
                WorkerTask.TRAIN_VAL,
                TaskStatus.SCHEDULED,
                run_id,
                chunk_id,
                config_options={"create_model_fn": create_model_fn},
            )
            self.logger.debug(f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}")

            # Small delay
            time.sleep(1)

        # set experiment task to idle
        self.db.set_experiment_current_task(ExperimentTask.IDLE)
        self.logger.debug(f"Set experiment task to {ExperimentTask.IDLE.value}.")

    except Exception as e:
        raise ControllerException(f"Error during run_fit: {e}") from e

    # shutdown workers
    self.worker_manager.shutdown()