rapidfireai 0.0.1__py3-none-any.whl → 0.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rapidfireai might be problematic.

Files changed (320)
  1. rapidfireai/__init__.py +11 -5
  2. rapidfireai/automl/__init__.py +20 -0
  3. rapidfireai/automl/base.py +48 -0
  4. rapidfireai/automl/datatypes.py +42 -0
  5. rapidfireai/automl/grid_search.py +125 -0
  6. rapidfireai/automl/model_config.py +102 -0
  7. rapidfireai/automl/random_search.py +145 -0
  8. rapidfireai/backend/__init__.py +0 -0
  9. rapidfireai/backend/chunks.py +63 -0
  10. rapidfireai/backend/controller.py +637 -0
  11. rapidfireai/backend/scheduler.py +137 -0
  12. rapidfireai/backend/worker.py +272 -0
  13. rapidfireai/cli.py +380 -0
  14. rapidfireai/db/__init__.py +0 -0
  15. rapidfireai/db/db_interface.py +135 -0
  16. rapidfireai/db/rf_db.py +694 -0
  17. rapidfireai/db/tables.sql +64 -0
  18. rapidfireai/dispatcher/dispatcher.py +391 -0
  19. rapidfireai/dispatcher/gunicorn.conf.py +25 -0
  20. rapidfireai/experiment.py +168 -0
  21. rapidfireai/frontend/build/asset-manifest.json +276 -0
  22. rapidfireai/frontend/build/favicon.ico +0 -0
  23. rapidfireai/frontend/build/index.html +1 -0
  24. rapidfireai/frontend/build/manifest.json +15 -0
  25. rapidfireai/frontend/build/pdf.worker.js +1 -0
  26. rapidfireai/frontend/build/report.html +39 -0
  27. rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
  28. rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
  29. rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
  30. rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
  31. rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
  32. rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
  33. rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
  34. rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
  35. rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
  36. rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
  37. rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
  38. rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
  39. rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
  40. rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
  41. rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
  42. rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
  43. rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
  44. rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
  45. rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
  46. rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
  47. rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
  48. rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
  49. rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
  50. rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
  51. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
  52. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
  53. rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
  54. rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
  55. rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
  56. rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
  57. rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
  58. rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
  59. rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
  60. rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
  61. rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
  62. rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
  63. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
  64. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
  65. rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
  66. rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
  67. rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
  68. rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
  69. rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
  70. rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
  71. rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
  72. rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
  73. rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
  74. rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
  75. rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
  76. rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
  77. rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
  78. rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
  79. rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
  80. rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
  81. rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
  82. rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
  83. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
  84. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
  85. rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
  86. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
  87. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
  88. rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
  89. rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
  90. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
  91. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
  92. rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
  93. rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
  94. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
  95. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
  96. rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
  97. rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
  98. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
  99. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
  100. rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
  101. rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
  102. rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
  103. rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
  104. rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
  105. rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
  106. rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
  107. rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
  108. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
  109. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
  110. rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
  111. rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
  112. rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
  113. rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
  114. rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
  115. rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
  116. rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
  117. rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
  118. rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
  119. rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
  120. rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
  121. rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
  122. rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
  123. rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
  124. rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
  125. rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
  126. rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
  127. rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
  128. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
  129. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
  130. rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
  131. rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
  132. rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
  133. rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
  134. rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
  135. rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
  136. rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
  137. rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
  138. rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
  139. rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
  140. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
  141. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
  142. rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
  143. rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
  144. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
  145. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
  146. rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
  147. rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
  148. rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
  149. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
  150. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
  151. rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
  152. rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
  153. rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
  154. rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
  155. rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
  156. rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
  157. rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
  158. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
  159. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
  160. rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
  161. rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
  162. rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
  163. rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
  164. rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
  165. rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
  166. rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
  167. rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
  168. rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
  169. rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
  170. rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
  171. rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
  172. rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
  173. rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
  174. rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
  175. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
  176. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
  177. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
  178. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
  179. rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
  180. rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
  181. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
  182. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
  183. rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
  184. rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
  185. rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
  186. rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
  187. rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
  188. rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
  189. rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
  190. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
  191. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
  192. rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
  193. rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
  194. rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
  195. rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
  196. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
  197. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
  198. rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
  199. rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
  200. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
  201. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
  202. rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
  203. rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
  204. rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
  205. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
  206. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
  207. rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
  208. rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
  209. rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
  210. rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
  211. rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
  212. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
  213. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
  214. rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
  215. rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
  216. rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
  217. rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
  218. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
  219. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
  220. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
  221. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
  222. rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
  223. rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
  224. rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
  225. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
  226. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
  227. rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
  228. rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
  229. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
  230. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
  231. rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
  232. rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
  233. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
  234. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
  235. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
  236. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
  237. rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
  238. rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
  239. rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
  240. rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
  241. rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
  242. rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
  243. rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
  244. rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
  245. rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
  246. rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
  247. rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
  248. rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
  249. rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
  250. rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
  251. rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
  252. rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
  253. rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
  254. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
  255. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
  256. rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
  257. rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
  258. rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
  259. rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
  260. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
  261. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
  262. rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
  263. rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
  264. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
  265. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
  266. rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
  267. rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
  268. rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
  269. rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
  270. rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
  271. rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
  272. rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
  273. rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
  274. rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
  275. rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
  276. rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
  277. rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
  278. rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
  279. rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
  280. rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
  281. rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
  282. rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
  283. rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
  284. rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
  285. rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
  286. rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
  287. rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
  288. rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
  289. rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
  290. rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
  291. rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
  292. rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
  293. rapidfireai/frontend/proxy_middleware.py +233 -0
  294. rapidfireai/frontend/server.py +25 -0
  295. rapidfireai/ml/__init__.py +0 -0
  296. rapidfireai/ml/callbacks.py +176 -0
  297. rapidfireai/ml/checkpoint_utils.py +540 -0
  298. rapidfireai/ml/trainer.py +309 -0
  299. rapidfireai/start.sh +634 -0
  300. rapidfireai/utils/__init__.py +0 -0
  301. rapidfireai/utils/automl_utils.py +51 -0
  302. rapidfireai/utils/constants.py +141 -0
  303. rapidfireai/utils/datapaths.py +69 -0
  304. rapidfireai/utils/exceptions.py +82 -0
  305. rapidfireai/utils/experiment_utils.py +370 -0
  306. rapidfireai/utils/logging.py +87 -0
  307. rapidfireai/utils/mlflow_manager.py +121 -0
  308. rapidfireai/utils/serialize.py +15 -0
  309. rapidfireai/utils/shm_manager.py +469 -0
  310. rapidfireai/utils/trainer_config.py +23 -0
  311. rapidfireai/utils/worker_manager.py +219 -0
  312. rapidfireai/version.py +6 -0
  313. rapidfireai-0.9.10.dist-info/METADATA +247 -0
  314. rapidfireai-0.9.10.dist-info/RECORD +318 -0
  315. rapidfireai-0.9.10.dist-info/entry_points.txt +2 -0
  316. rapidfireai-0.0.1.dist-info/METADATA +0 -37
  317. rapidfireai-0.0.1.dist-info/RECORD +0 -6
  318. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/WHEEL +0 -0
  319. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/licenses/LICENSE +0 -0
  320. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/top_level.txt +0 -0
rapidfireai/ml/checkpoint_utils.py
@@ -0,0 +1,540 @@
+import copy
+import json
+import os
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
+from transformers import AutoTokenizer
+from trl import DPOTrainer, GRPOTrainer, SFTTrainer
+
+from rapidfireai.utils.constants import SHMObjectType
+from rapidfireai.utils.datapaths import DataPath
+from rapidfireai.utils.shm_manager import SharedMemoryManager
+from rapidfireai.utils.trainer_config import TrainerConfig
+
+
+def move_tensors_to_device(obj, device: torch.device):
+    """Recursively move all tensors in a nested structure to device"""
+    if isinstance(obj, torch.Tensor):
+        return obj.to(device, non_blocking=True)
+    elif isinstance(obj, dict):
+        return {key: move_tensors_to_device(value, device) for key, value in obj.items()}
+    elif isinstance(obj, list):
+        return [move_tensors_to_device(item, device) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(move_tensors_to_device(item, device) for item in obj)
+    else:
+        return obj
+
+
+def move_tensors_to_cpu(obj):
+    """Recursively move all tensors in a nested structure to CPU"""
+    if isinstance(obj, torch.Tensor):
+        return obj.cpu().clone()
+    elif isinstance(obj, dict):
+        return {key: move_tensors_to_cpu(value) for key, value in obj.items()}
+    elif isinstance(obj, list):
+        return [move_tensors_to_cpu(item) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(move_tensors_to_cpu(item) for item in obj)
+    else:
+        return obj
+
+
+def ensure_gradient_compatibility(model, use_peft: bool = False):
+    """Ensure model parameters have proper gradient settings"""
+    model.train()
+    if use_peft:
+        model.base_model.eval()
+        for name, param in model.named_parameters():
+            if any(adapter_key in name.lower() for adapter_key in ["lora", "adapter", "modules_to_save"]):
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+    else:
+        for param in model.parameters():
+            param.requires_grad = True
+        for n, p in model.named_parameters():
+            if "reference" in n:
+                p.requires_grad = False
+    model.train()
+    torch.set_grad_enabled(True)
+
+    return model
+
+
+def _configure_tokenizer(tokenizer: AutoTokenizer) -> None:
+    """Configure tokenizer with proper padding token."""
+    if tokenizer.pad_token is None:
+        if tokenizer.eos_token is not None:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+        else:
+            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+
+def create_model_instance(
+    model_config: dict,
+    create_model_fn: Callable,
+    checkpoint_path: Optional[str] = None,
+    is_peft: bool = False,
+    device: Optional[str] = None,
+) -> tuple[nn.Module, AutoTokenizer]:
+    """Create a model instance from a model configuration"""
+    if device is not None:
+        model_config["model_kwargs"]["device_map"] = {"": device}
+    if checkpoint_path and not is_peft:
+        model_config["model_name"] = checkpoint_path
+
+    model_instance, tokenizer = create_model_fn(model_config)
+
+    if is_peft and checkpoint_path:
+        model_instance = PeftModel.from_pretrained(model_instance, checkpoint_path)
+    _configure_tokenizer(tokenizer)
+
+    return model_instance, tokenizer
+
+
+def save_checkpoint_to_shared_memory(
+    trainer: Union[SFTTrainer, DPOTrainer, GRPOTrainer], trainer_config: TrainerConfig, shm_manager: SharedMemoryManager
+) -> None:
+    """Save checkpoint to shared memory"""
+    checkpoint = {}
+
+    if hasattr(trainer.model, "peft_config"):
+        peft_state_dict = get_peft_model_state_dict(
+            trainer.model,
+            adapter_name=trainer_config.config_leaf.get("training_args").get("model_adapter_name", "default"),
+        )
+        checkpoint["state_dict"] = {k: v.cpu().clone() for k, v in peft_state_dict.items()}
+        checkpoint["adapter_config"] = trainer.model.peft_config
+
+    if trainer.optimizer is not None:
+        checkpoint["optimizer_state"] = move_tensors_to_cpu(trainer.optimizer.state_dict())
+
+    if trainer.lr_scheduler is not None:
+        checkpoint["scheduler_state"] = move_tensors_to_cpu(trainer.lr_scheduler.state_dict())
+
+    if hasattr(trainer, "state"):
+        checkpoint["trainer_state"] = move_tensors_to_cpu(trainer.state.__dict__.copy())
+
+    if hasattr(trainer, "rng_state"):
+        checkpoint["rng_state"] = move_tensors_to_cpu(trainer.rng_state)
+
+    if hasattr(trainer.args, "__dict__"):
+        checkpoint["training_args"] = trainer.args.__dict__.copy()
+
+    if hasattr(trainer.model, "generation_config"):
+        checkpoint["generation_config"] = trainer.model.generation_config.to_dict()
+
+    if hasattr(trainer, "scaler"):
+        checkpoint["scaler"] = move_tensors_to_cpu(trainer.scaler.state_dict())
+
+    if hasattr(trainer.model, "config"):
+        config = trainer.model.config
+        if hasattr(config, "special_tokens_map"):
+            checkpoint["special_tokens_map"] = config.special_tokens_map
+        if hasattr(config, "tokenizer_config"):
+            checkpoint["tokenizer_config"] = config.tokenizer_config
+
+    shm_manager.save_model_object(trainer_config.run_id, SHMObjectType.CHECKPOINTS, checkpoint)
+
+
+def load_checkpoint_from_shared_memory(
+    trainer_config: TrainerConfig, shm_manager: SharedMemoryManager, ref: bool = False, is_peft: bool = False
+) -> tuple[nn.Module, AutoTokenizer, dict]:
+    """Load checkpoint from shared memory"""
+    run_id = trainer_config.run_id
+    device = "cuda:0"
+    base_model = None
+    model_id = trainer_config.config_leaf.get("model_name")
+
+    if trainer_config.warm_started_from is not None and not shm_manager.model_exists(run_id):
+        shm_manager.create_warm_start_checkpoint(run_id, trainer_config.warm_started_from)
+
+    if is_peft:
+        if not shm_manager.model_exists(model_id):
+            base_model, tokenizer = create_model_instance(
+                trainer_config.config_leaf,
+                trainer_config.create_model_fn,
+                checkpoint_path=None,
+                is_peft=is_peft,
+                device=device,
+            )
+            if model_id is None:
+                model_id = base_model.config._name_or_path
+            save_model_to_shared_memory(
+                base_model, tokenizer, trainer_config, shm_manager, SHMObjectType.BASE_MODEL, model_id
+            )
+        else:
+            base_model, tokenizer = load_model_from_shared_memory(
+                trainer_config, shm_manager, SHMObjectType.BASE_MODEL, model_id
+            )
+
+    if base_model == "" or (not is_peft and not shm_manager.model_exists(run_id)):
+        base_model, tokenizer = create_model_instance(
+            trainer_config.config_leaf,
+            trainer_config.create_model_fn,
+            checkpoint_path=None,
+            is_peft=is_peft,
+            device=device,
+        )
+
+    model = base_model
+    peft_config = LoraConfig(**trainer_config.config_leaf.get("peft_params", {}))
+    if is_peft:
+        model = get_peft_model(model, peft_config)
+
+    # Load weights from shared memory
+    if trainer_config.completed_steps > 0 or trainer_config.warm_started_from is not None:
+        checkpoint = shm_manager.load_model_object(run_id, SHMObjectType.CHECKPOINTS)
+
+        if "adapter_config" in checkpoint:
+            if trainer_config.config_leaf.get("trainer_type") == "DPO" and is_peft:
+                reference_state_dict = shm_manager.load_model_object(
+                    trainer_config.run_id, SHMObjectType.REF_STATE_DICT
+                )
+                reference_state_dict = move_tensors_to_device(reference_state_dict, device)
+                model.add_adapter(
+                    trainer_config.config_leaf.get("training_args", {}).get("ref_adapter_name"), peft_config
+                )
+                model.set_adapter(
+                    trainer_config.config_leaf.get("training_args", {}).get("model_adapter_name", "default")
+                )
+                set_peft_model_state_dict(
+                    model,
+                    reference_state_dict,
+                    adapter_name=trainer_config.config_leaf.get("training_args").get("ref_adapter_name"),
+                )
+                model.set_adapter(trainer_config.config_leaf.get("training_args").get("model_adapter_name", "default"))
+
+        if checkpoint.get("state_dict"):
+            state_dict = {k: v.to(device) for k, v in checkpoint["state_dict"].items()}
+            if is_peft:
+                set_peft_model_state_dict(
+                    model,
+                    state_dict,
+                    adapter_name=trainer_config.config_leaf.get("training_args", {}).get(
+                        "model_adapter_name", "default"
+                    ),
+                )
+                if trainer_config.config_leaf.get("trainer_type") == "DPO" and is_peft:
+                    model.set_adapter(
+                        trainer_config.config_leaf.get("training_args", {}).get("model_adapter_name", "default")
+                    )
+            else:
+                model.load_state_dict(state_dict)
+
+    elif not is_peft:
+        model, tokenizer = load_model_from_shared_memory(
+            trainer_config, shm_manager, SHMObjectType.FULL_MODEL, trainer_config.run_id
+        )
+
+    return model, tokenizer
+
+
+def load_model_from_shared_memory(
+    trainer_config: TrainerConfig, shm_manager: SharedMemoryManager, model_object_type: SHMObjectType, model_id: str
+) -> tuple[nn.Module, AutoTokenizer]:
+    """Load model from shared memory"""
+    model_data = shm_manager.load_model_object(model_id, model_object_type)
+    model = copy.deepcopy(model_data[model_object_type])
+    tokenizer = model_data["tokenizer"]
+    bnb_modules = move_tensors_to_device(model_data["bnb_modules"], device="cuda:0")
+    model = get_model_to_device(model, bnb_modules, device="cuda:0")
+    return model, tokenizer
+
+
+def save_model_to_shared_memory(
+    model: Union[nn.Module, str],
+    tokenizer: AutoTokenizer,
+    trainer_config: TrainerConfig,
+    shm_manager: SharedMemoryManager,
+    model_type: str,
+    model_id: str,
+) -> None:
+    """Save model to shared memory"""
+    if model_type != SHMObjectType.FULL_MODEL and shm_manager.model_exists(model_id):
+        return
+    model_cpu = model.cpu()
+    model_data = {model_type: model_cpu, "tokenizer": tokenizer}
+    shm_manager.save_model_object(model_id, model_type, model_data)
+
+
+def load_or_create_ref_model(
+    model_instance,
+    trainer_config: TrainerConfig,
+    device: str,
+    use_shared_memory: bool,
+    shm_manager: SharedMemoryManager,
+) -> Optional[nn.Module]:
+    """Load or create reference model for DPO training based on configuration"""
+    config_leaf = trainer_config.config_leaf
+    device = "cuda:0"
+    ref_model_name = trainer_config.config_leaf.get("ref_model_config", {}).get("model_name", None)
+    model_id = trainer_config.config_leaf.get("model_name")
+    ref_model_id = "ref_" + (trainer_config.config_leaf.get("ref_model_config", {}).get("model_name") or model_id)
+    if use_shared_memory and shm_manager.model_exists(ref_model_id):
+        ref_model_instance, _ = load_model_from_shared_memory(
+            trainer_config, shm_manager, SHMObjectType.REF_FULL_MODEL, ref_model_id
+        )
+    else:
+        if ref_model_name is not None:
+            ref_model_instance, _ = create_model_instance(
+                config_leaf.get("ref_model_config"), trainer_config.create_model_fn, device=device
+            )
+        elif trainer_config.completed_steps == 0:
+            ref_model_instance = copy.deepcopy(model_instance)
+        save_model_to_shared_memory(
+            ref_model_instance, None, trainer_config, shm_manager, SHMObjectType.REF_FULL_MODEL, ref_model_id
+        )
+
+    return ref_model_instance
+
+
+def get_model_to_device(model, bnb_modules, device="cuda:0"):
+    """Move model from shared memory to specified device with proper BitsAndBytes restoration"""
+    for name, param in model.named_parameters():
+        if param.data is not None:
+            param.data = move_tensors_to_device(param.data, device)
+
+    for name, buffer in model.named_buffers():
+        if isinstance(buffer, torch.Tensor) and buffer is not None:
+            parent_module = model
+            attr_path = name.split(".")
+
+            for attr in attr_path[:-1]:
+                parent_module = getattr(parent_module, attr)
+
+            device_buffer = move_tensors_to_device(buffer, device)
+            setattr(parent_module, attr_path[-1], device_buffer)
+
+    for name, module in model.named_modules():
+        if not hasattr(module, "weight"):
+            continue
+
+        try:
+            import bitsandbytes as bnb
+
+            bnb_layer_types = [bnb.nn.Linear4bit, bnb.nn.LinearFP4, bnb.nn.LinearNF4, bnb.nn.Params4bit]
+        except ImportError:
+            continue
+
+        is_bnb_layer = any(isinstance(module, layer_type) for layer_type in bnb_layer_types)
+
+        if is_bnb_layer and name in bnb_modules:
+            bnb_attrs = bnb_modules[name]
+            weight = module.weight
+
+            if hasattr(weight, "data") and weight.data is not None:
+                weight.data = move_tensors_to_device(weight.data, device)
+
+            if "quant_state_data" in bnb_attrs:
+                if not hasattr(weight, "quant_state") or weight.quant_state is None:
+                    from bitsandbytes.functional import QuantState
+
+                    quant_data = bnb_attrs["quant_state_data"]
+
+                    weight.quant_state = QuantState(absmax=quant_data["absmax"], code=quant_data["code"])
+
+                quant_data = bnb_attrs["quant_state_data"]
+
+                for attr, value in quant_data.items():
+                    if isinstance(value, torch.Tensor):
+                        value = move_tensors_to_device(value, device)
+                    setattr(weight.quant_state, attr, value)
+            if "state2_data" in bnb_attrs:
+                for attr, value in bnb_attrs["state2_data"].items():
+                    if isinstance(value, torch.Tensor):
+                        value = move_tensors_to_device(value, device)
+                    setattr(weight.quant_state.state2, attr, value)
+
+            for attr, value in bnb_attrs.items():
+                if attr not in ["quant_state_data", "quant_state_class", "weight_class", "state2_data"]:
+                    if isinstance(value, torch.Tensor):
+                        value = move_tensors_to_device(value, device)
+                    setattr(weight, attr, value)
+
+        elif hasattr(module, "weight") and hasattr(module.weight, "data") and module.weight.data is not None:
+            if name not in bnb_modules:
+                module.weight.data = move_tensors_to_device(module.weight.data, device)
+
+    model = model.to(device)
+    return model
+
+
+def restore_trainer_from_shared_memory(
+    trainer: Union[SFTTrainer, DPOTrainer, GRPOTrainer],
+    trainer_config: TrainerConfig,
+    shm_manager: SharedMemoryManager,
+) -> Union[SFTTrainer, DPOTrainer, GRPOTrainer]:
+    """Restore complete training state to trainer"""
+    try:
+        device = next(trainer.model.parameters()).device
+
+        if shm_manager.model_exists(trainer_config.run_id):
+            training_state = shm_manager.load_model_object(trainer_config.run_id, SHMObjectType.CHECKPOINTS)
+        else:
+            raise ValueError(f"Training state for run {trainer_config.run_id} not found in shared memory")
+
+        if training_state.get("trainer_state") is not None and hasattr(trainer, "state"):
+            trainer_state_dict = training_state["trainer_state"]
+            device_trainer_state = move_tensors_to_device(trainer_state_dict, device)
+            for key, value in device_trainer_state.items():
+                if hasattr(trainer.state, key):
+                    setattr(trainer.state, key, value)
+
+        if training_state.get("optimizer_state") is not None and trainer.optimizer is not None:
+            optimizer_state = training_state["optimizer_state"]
+            device_optimizer_state = move_tensors_to_device(optimizer_state, device)
+            trainer.optimizer.load_state_dict(device_optimizer_state)
+
+        if training_state.get("scheduler_state") is not None and trainer.lr_scheduler is not None:
+            scheduler_state = training_state["scheduler_state"]
+            device_scheduler_state = move_tensors_to_device(scheduler_state, device)
+            trainer.lr_scheduler.load_state_dict(device_scheduler_state)
+
+            if hasattr(trainer.state, "global_step"):
+                trainer.lr_scheduler._step_count = trainer.state.global_step + 1
+
+        if training_state.get("rng_state") is not None:
+            rng_state = training_state["rng_state"]
+            device_rng_state = move_tensors_to_device(rng_state, device)
+            trainer.rng_state = device_rng_state
+
+        if training_state.get("generation_config") is not None and hasattr(trainer.model, "generation_config"):
+            trainer.model.generation_config = type(trainer.model.generation_config)(
+                **training_state["generation_config"]
+            )
+        if training_state.get("scaler") is not None:
+            trainer.scaler.load_state_dict(training_state["scaler"])
+
+    except Exception as e:
+        print(f"Warning: Error restoring training state: {e}")
+
+    return trainer
+
+
+def save_checkpoint_to_disk(
+    trainer: Union[SFTTrainer, DPOTrainer, GRPOTrainer],
+    trainer_config: TrainerConfig,
+    first: bool = False,
+    last: bool = False,
+    completed_steps: int = 0,
+) -> None:
+    base_run_path = DataPath.base_run_path(trainer_config.run_id)
+    if first:
+        checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
+    elif last:
+        checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
+    else:
+        checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path) / "checkpoint"
+    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
+    os.makedirs(checkpoint_path, exist_ok=True)
+
+    trainer.model.save_pretrained(checkpoint_path)
+
+    trainer_state_path = checkpoint_path / "trainer_state.json"
+    trainer_state_dict = trainer.state.__dict__.copy()
+    with open(trainer_state_path, "w") as f:
+        json.dump(trainer_state_dict, f, indent=2)
+
+    if trainer.optimizer is not None:
+        optimizer_path = checkpoint_path / "optimizer.pt"
+        optimizer_state = move_tensors_to_cpu(trainer.optimizer.state_dict())
+        torch.save(optimizer_state, optimizer_path)
+
+    if trainer.lr_scheduler is not None:
+        scheduler_path = checkpoint_path / "scheduler.pt"
+        scheduler_state = move_tensors_to_cpu(trainer.lr_scheduler.state_dict())
+        torch.save(scheduler_state, scheduler_path)
+
+    if hasattr(trainer, "rng_state"):
+        rng_state_path = checkpoint_path / "rng_state.pth"
+        torch.save(trainer.rng_state, rng_state_path)
+
+
+def load_checkpoint_from_disk(
+    trainer_config: TrainerConfig, ref: bool = False, is_peft: bool = False
+) -> tuple[nn.Module, AutoTokenizer, dict]:
+    """Load checkpoint from disk"""
+    device = "cuda:0"
+    checkpoint_path = None
+    if trainer_config.warm_started_from is not None and trainer_config.completed_steps == 0:
+        base_run_path = DataPath.base_run_path(trainer_config.warm_started_from)
+        checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path) / "checkpoint"
+    elif trainer_config.completed_steps > 0:
+        base_run_path = DataPath.base_run_path(trainer_config.run_id)
+        checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path) / "checkpoint"
+
+    model_instance, tokenizer = create_model_instance(
+        trainer_config.config_leaf, trainer_config.create_model_fn, checkpoint_path, is_peft=is_peft, device=device
+    )
+    if is_peft and checkpoint_path is None:
+        model_instance = get_peft_model(model_instance, LoraConfig(**trainer_config.config_leaf.get("peft_params", {})))
+
+    if ref:
+        model_instance, tokenizer = create_model_instance(
+            trainer_config.config_leaf.get("ref_model_config", {}), trainer_config.create_model_fn, device=device
+        )
+
+    return model_instance, tokenizer
+
+
+def restore_trainer_from_disk(
+    trainer: Union[SFTTrainer, DPOTrainer, GRPOTrainer], trainer_config: TrainerConfig
+) -> Union[SFTTrainer, DPOTrainer, GRPOTrainer]:
+    """Restore trainer from disk with proper state accumulation"""
+    base_run_path = DataPath.base_run_path(trainer_config.run_id)
+    checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path) / "checkpoint"
+    device = "cuda:0"
+
+    trainer_state_path = checkpoint_path / "trainer_state.json"
+    if trainer_state_path.exists():
+        with open(trainer_state_path) as f:
+            trainer_state_dict = json.load(f)
+
+        for key, value in trainer_state_dict.items():
+            if hasattr(trainer.state, key):
+                setattr(trainer.state, key, value)
+
+    optimizer_path = checkpoint_path / "optimizer.pt"
+    if optimizer_path.exists() and trainer.optimizer is not None:
+        optimizer_state = torch.load(optimizer_path, map_location=device)
+        model_device = next(trainer.model.parameters()).device
+        optimizer_state = move_tensors_to_device(optimizer_state, model_device)
+        trainer.optimizer.load_state_dict(optimizer_state)
+
+    lr_scheduler_path = checkpoint_path / "scheduler.pt"
+    if lr_scheduler_path.exists() and trainer.lr_scheduler is not None:
+        lr_scheduler_state = torch.load(lr_scheduler_path, map_location=device)
+        model_device = next(trainer.model.parameters()).device
+        lr_scheduler_state = move_tensors_to_device(lr_scheduler_state, model_device)
+        trainer.lr_scheduler.load_state_dict(lr_scheduler_state)
+
+    if hasattr(trainer.state, "global_step") and trainer.lr_scheduler is not None:
+        trainer.lr_scheduler._step_count = trainer.state.global_step + 1
+
+    rng_state_path = checkpoint_path / "rng_state.pth"
+    if rng_state_path.exists():
+        rng_state = torch.load(rng_state_path, map_location=device, weights_only=False)
+        model_device = next(trainer.model.parameters()).device
+        rng_state = move_tensors_to_device(rng_state, model_device)
+        trainer.rng_state = rng_state
+
+    return trainer
+
+
+def save_ref_model_to_disk(model_instance: nn.Module, trainer_config: TrainerConfig, ref: bool = False) -> None:
+    """Save reference model to disk"""
+    base_run_path = DataPath.base_run_path(trainer_config.run_id)
+    ref_model_path = DataPath.ref_model_path(base_run_path)
+    os.makedirs(ref_model_path, exist_ok=True)
+
+    if hasattr(model_instance, "peft_config"):
+        peft_state_dict = get_peft_model_state_dict(model_instance)
+        torch.save(peft_state_dict, ref_model_path / "pytorch_model.bin")
+    else:
+        torch.save(model_instance.state_dict(), ref_model_path / "pytorch_model.bin")