rapidfireai 0.0.1__py3-none-any.whl → 0.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidfireai might be problematic. Click here for more details.

Files changed (320) hide show
  1. rapidfireai/__init__.py +11 -5
  2. rapidfireai/automl/__init__.py +20 -0
  3. rapidfireai/automl/base.py +48 -0
  4. rapidfireai/automl/datatypes.py +42 -0
  5. rapidfireai/automl/grid_search.py +125 -0
  6. rapidfireai/automl/model_config.py +102 -0
  7. rapidfireai/automl/random_search.py +145 -0
  8. rapidfireai/backend/__init__.py +0 -0
  9. rapidfireai/backend/chunks.py +63 -0
  10. rapidfireai/backend/controller.py +637 -0
  11. rapidfireai/backend/scheduler.py +137 -0
  12. rapidfireai/backend/worker.py +272 -0
  13. rapidfireai/cli.py +380 -0
  14. rapidfireai/db/__init__.py +0 -0
  15. rapidfireai/db/db_interface.py +135 -0
  16. rapidfireai/db/rf_db.py +694 -0
  17. rapidfireai/db/tables.sql +64 -0
  18. rapidfireai/dispatcher/dispatcher.py +391 -0
  19. rapidfireai/dispatcher/gunicorn.conf.py +25 -0
  20. rapidfireai/experiment.py +168 -0
  21. rapidfireai/frontend/build/asset-manifest.json +276 -0
  22. rapidfireai/frontend/build/favicon.ico +0 -0
  23. rapidfireai/frontend/build/index.html +1 -0
  24. rapidfireai/frontend/build/manifest.json +15 -0
  25. rapidfireai/frontend/build/pdf.worker.js +1 -0
  26. rapidfireai/frontend/build/report.html +39 -0
  27. rapidfireai/frontend/build/static/css/1482.3b7bf531.chunk.css +1 -0
  28. rapidfireai/frontend/build/static/css/2730.3f8937ff.chunk.css +1 -0
  29. rapidfireai/frontend/build/static/css/318.0def90a7.css +7 -0
  30. rapidfireai/frontend/build/static/css/4762.9b7b71f7.chunk.css +1 -0
  31. rapidfireai/frontend/build/static/css/4950.487ecc8b.chunk.css +1 -0
  32. rapidfireai/frontend/build/static/css/5170.2574ce9d.chunk.css +1 -0
  33. rapidfireai/frontend/build/static/css/6121.4d541986.chunk.css +1 -0
  34. rapidfireai/frontend/build/static/css/6343.dd6979f2.chunk.css +1 -0
  35. rapidfireai/frontend/build/static/css/6534.433c213f.chunk.css +1 -0
  36. rapidfireai/frontend/build/static/css/6920.ffac4b2a.css +2 -0
  37. rapidfireai/frontend/build/static/css/7246.bf2f0c87.css +9 -0
  38. rapidfireai/frontend/build/static/css/7367.dd6979f2.chunk.css +1 -0
  39. rapidfireai/frontend/build/static/css/8690.05d081e5.chunk.css +1 -0
  40. rapidfireai/frontend/build/static/css/9531.d0910d3c.chunk.css +1 -0
  41. rapidfireai/frontend/build/static/css/9780.363e4943.chunk.css +1 -0
  42. rapidfireai/frontend/build/static/css/main~d91a9049.c0be472c.css +1 -0
  43. rapidfireai/frontend/build/static/js/1000.e5ed264b.chunk.js +1 -0
  44. rapidfireai/frontend/build/static/js/1012.ac98ab59.chunk.js +1 -0
  45. rapidfireai/frontend/build/static/js/1079.6c13ac0d.js +1 -0
  46. rapidfireai/frontend/build/static/js/110.9059f3b8.chunk.js +1 -0
  47. rapidfireai/frontend/build/static/js/1142.872d0010.chunk.js +1 -0
  48. rapidfireai/frontend/build/static/js/1167.9a6da14c.chunk.js +1 -0
  49. rapidfireai/frontend/build/static/js/1248.60890b4f.chunk.js +1 -0
  50. rapidfireai/frontend/build/static/js/1262.83dc7673.chunk.js +1 -0
  51. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js +2 -0
  52. rapidfireai/frontend/build/static/js/1273.56da3e13.chunk.js.LICENSE.txt +9 -0
  53. rapidfireai/frontend/build/static/js/1303.7d19305c.chunk.js +1 -0
  54. rapidfireai/frontend/build/static/js/1351.45076ff3.chunk.js +1 -0
  55. rapidfireai/frontend/build/static/js/1355.b896a592.js +1 -0
  56. rapidfireai/frontend/build/static/js/1357.02c46a02.chunk.js +1 -0
  57. rapidfireai/frontend/build/static/js/1470.c51d60c6.chunk.js +1 -0
  58. rapidfireai/frontend/build/static/js/1482.23b74f50.chunk.js +1 -0
  59. rapidfireai/frontend/build/static/js/1500.19799d8d.chunk.js +1 -0
  60. rapidfireai/frontend/build/static/js/1648.d3b9edc7.chunk.js +1 -0
  61. rapidfireai/frontend/build/static/js/1860.7d96e3f9.chunk.js +1 -0
  62. rapidfireai/frontend/build/static/js/1909.5b1d9ff4.chunk.js +1 -0
  63. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js +2 -0
  64. rapidfireai/frontend/build/static/js/1928.44245110.chunk.js.LICENSE.txt +11 -0
  65. rapidfireai/frontend/build/static/js/1933.deba26ca.chunk.js +1 -0
  66. rapidfireai/frontend/build/static/js/21.aac92802.chunk.js +1 -0
  67. rapidfireai/frontend/build/static/js/2103.0ca12071.chunk.js +1 -0
  68. rapidfireai/frontend/build/static/js/2258.b3b8fab4.chunk.js +1 -0
  69. rapidfireai/frontend/build/static/js/2289.9ad51e87.chunk.js +1 -0
  70. rapidfireai/frontend/build/static/js/2323.7dd927d7.js +2 -0
  71. rapidfireai/frontend/build/static/js/2323.7dd927d7.js.LICENSE.txt +1 -0
  72. rapidfireai/frontend/build/static/js/2346.ed99ca72.chunk.js +1 -0
  73. rapidfireai/frontend/build/static/js/2386.0a660834.chunk.js +1 -0
  74. rapidfireai/frontend/build/static/js/2402.465048f9.chunk.js +1 -0
  75. rapidfireai/frontend/build/static/js/243.5a83bbca.chunk.js +1 -0
  76. rapidfireai/frontend/build/static/js/2589.68571e16.js +1 -0
  77. rapidfireai/frontend/build/static/js/2647.65092bab.chunk.js +1 -0
  78. rapidfireai/frontend/build/static/js/2691.65d4a4e7.js +1 -0
  79. rapidfireai/frontend/build/static/js/2730.b38dd6f3.chunk.js +1 -0
  80. rapidfireai/frontend/build/static/js/2746.ef752da4.chunk.js +1 -0
  81. rapidfireai/frontend/build/static/js/2779.580d4491.chunk.js +1 -0
  82. rapidfireai/frontend/build/static/js/2799.fe5993b2.chunk.js +1 -0
  83. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js +2 -0
  84. rapidfireai/frontend/build/static/js/2844.9708db79.chunk.js.LICENSE.txt +21 -0
  85. rapidfireai/frontend/build/static/js/2901.ee0c606b.chunk.js +1 -0
  86. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js +2 -0
  87. rapidfireai/frontend/build/static/js/2932.7cc0689b.chunk.js.LICENSE.txt +6 -0
  88. rapidfireai/frontend/build/static/js/2956.a393c8cc.chunk.js +1 -0
  89. rapidfireai/frontend/build/static/js/2972.679bed05.chunk.js +1 -0
  90. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js +2 -0
  91. rapidfireai/frontend/build/static/js/2985.7e51cdfa.chunk.js.LICENSE.txt +51 -0
  92. rapidfireai/frontend/build/static/js/3093.488df653.js +1 -0
  93. rapidfireai/frontend/build/static/js/3145.66ee61b9.js +1 -0
  94. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js +2 -0
  95. rapidfireai/frontend/build/static/js/3170.a22f966a.chunk.js.LICENSE.txt +21 -0
  96. rapidfireai/frontend/build/static/js/3307.f6fb258c.chunk.js +1 -0
  97. rapidfireai/frontend/build/static/js/3325.d5b03d65.js +1 -0
  98. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js +2 -0
  99. rapidfireai/frontend/build/static/js/3334.2d6704df.chunk.js.LICENSE.txt +6 -0
  100. rapidfireai/frontend/build/static/js/3387.bb8edad3.chunk.js +1 -0
  101. rapidfireai/frontend/build/static/js/3448.438e6579.chunk.js +1 -0
  102. rapidfireai/frontend/build/static/js/3460.735eea87.chunk.js +1 -0
  103. rapidfireai/frontend/build/static/js/3505.7fd3921a.js +2 -0
  104. rapidfireai/frontend/build/static/js/3505.7fd3921a.js.LICENSE.txt +9 -0
  105. rapidfireai/frontend/build/static/js/3510.cd167a00.js +2 -0
  106. rapidfireai/frontend/build/static/js/3510.cd167a00.js.LICENSE.txt +18 -0
  107. rapidfireai/frontend/build/static/js/3563.cc828e19.chunk.js +1 -0
  108. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js +2 -0
  109. rapidfireai/frontend/build/static/js/359.08960b84.chunk.js.LICENSE.txt +4 -0
  110. rapidfireai/frontend/build/static/js/3608.403b4b79.chunk.js +1 -0
  111. rapidfireai/frontend/build/static/js/3652.cb8add7f.js +1 -0
  112. rapidfireai/frontend/build/static/js/3775.5230b157.chunk.js +1 -0
  113. rapidfireai/frontend/build/static/js/3817.53555d18.js +2 -0
  114. rapidfireai/frontend/build/static/js/3817.53555d18.js.LICENSE.txt +18 -0
  115. rapidfireai/frontend/build/static/js/3835.d9946ff9.chunk.js +1 -0
  116. rapidfireai/frontend/build/static/js/3964.874f0297.chunk.js +1 -0
  117. rapidfireai/frontend/build/static/js/3968.275cbc3d.chunk.js +1 -0
  118. rapidfireai/frontend/build/static/js/3999.765cbd82.chunk.js +1 -0
  119. rapidfireai/frontend/build/static/js/4020.4452c046.chunk.js +1 -0
  120. rapidfireai/frontend/build/static/js/4138.2f6f6d9f.js +1 -0
  121. rapidfireai/frontend/build/static/js/4160.f424554c.js +1 -0
  122. rapidfireai/frontend/build/static/js/4180.50cea095.chunk.js +1 -0
  123. rapidfireai/frontend/build/static/js/4221.b0bba3f5.chunk.js +1 -0
  124. rapidfireai/frontend/build/static/js/4250.5bb49278.chunk.js +1 -0
  125. rapidfireai/frontend/build/static/js/4297.15777d8f.chunk.js +1 -0
  126. rapidfireai/frontend/build/static/js/4349.c965f2de.js +2 -0
  127. rapidfireai/frontend/build/static/js/4349.c965f2de.js.LICENSE.txt +1 -0
  128. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js +2 -0
  129. rapidfireai/frontend/build/static/js/4484.4cbe5e7f.js.LICENSE.txt +10 -0
  130. rapidfireai/frontend/build/static/js/4578.a8124588.js +1 -0
  131. rapidfireai/frontend/build/static/js/4596.89a97480.js +1 -0
  132. rapidfireai/frontend/build/static/js/4748.566f435a.chunk.js +1 -0
  133. rapidfireai/frontend/build/static/js/4762.928e8a90.chunk.js +1 -0
  134. rapidfireai/frontend/build/static/js/4768.7945be63.js +2 -0
  135. rapidfireai/frontend/build/static/js/4768.7945be63.js.LICENSE.txt +1 -0
  136. rapidfireai/frontend/build/static/js/4804.26b50dd4.chunk.js +1 -0
  137. rapidfireai/frontend/build/static/js/4850.62390a45.chunk.js +1 -0
  138. rapidfireai/frontend/build/static/js/4862.a0ccb221.chunk.js +1 -0
  139. rapidfireai/frontend/build/static/js/491.5dc8ed40.chunk.js +1 -0
  140. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js +2 -0
  141. rapidfireai/frontend/build/static/js/492.9262f038.chunk.js.LICENSE.txt +6 -0
  142. rapidfireai/frontend/build/static/js/4943.6d345fd3.chunk.js +1 -0
  143. rapidfireai/frontend/build/static/js/4950.bc182e62.chunk.js +1 -0
  144. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js +2 -0
  145. rapidfireai/frontend/build/static/js/5042.d4f0c65a.chunk.js.LICENSE.txt +6 -0
  146. rapidfireai/frontend/build/static/js/5170.0065e96f.chunk.js +1 -0
  147. rapidfireai/frontend/build/static/js/5222.35c74a52.js +2 -0
  148. rapidfireai/frontend/build/static/js/5222.35c74a52.js.LICENSE.txt +10 -0
  149. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js +2 -0
  150. rapidfireai/frontend/build/static/js/5223.3224f019.chunk.js.LICENSE.txt +3 -0
  151. rapidfireai/frontend/build/static/js/5229.7dd42316.chunk.js +1 -0
  152. rapidfireai/frontend/build/static/js/5286.4c1ad26b.js +1 -0
  153. rapidfireai/frontend/build/static/js/5486.21cff711.chunk.js +1 -0
  154. rapidfireai/frontend/build/static/js/5526.7b368956.chunk.js +1 -0
  155. rapidfireai/frontend/build/static/js/5605.1ee4d87b.chunk.js +1 -0
  156. rapidfireai/frontend/build/static/js/5682.40b42d8b.chunk.js +1 -0
  157. rapidfireai/frontend/build/static/js/5794.9433d867.chunk.js +1 -0
  158. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js +2 -0
  159. rapidfireai/frontend/build/static/js/5826.38a56e8c.chunk.js.LICENSE.txt +1 -0
  160. rapidfireai/frontend/build/static/js/5862.50f42a0b.js +1 -0
  161. rapidfireai/frontend/build/static/js/5895.e26742f1.chunk.js +1 -0
  162. rapidfireai/frontend/build/static/js/5919.edd4a5cf.chunk.js +1 -0
  163. rapidfireai/frontend/build/static/js/598.a0e792ae.js +1 -0
  164. rapidfireai/frontend/build/static/js/6058.74162bf9.chunk.js +1 -0
  165. rapidfireai/frontend/build/static/js/618.06051134.chunk.js +2 -0
  166. rapidfireai/frontend/build/static/js/618.06051134.chunk.js.LICENSE.txt +21 -0
  167. rapidfireai/frontend/build/static/js/6335.9fca442d.chunk.js +1 -0
  168. rapidfireai/frontend/build/static/js/6336.e05e1154.chunk.js +1 -0
  169. rapidfireai/frontend/build/static/js/6343.2bcd28ff.chunk.js +1 -0
  170. rapidfireai/frontend/build/static/js/6363.a319b8f2.chunk.js +1 -0
  171. rapidfireai/frontend/build/static/js/6478.344abf25.chunk.js +1 -0
  172. rapidfireai/frontend/build/static/js/6504.1c004564.js +1 -0
  173. rapidfireai/frontend/build/static/js/6534.ec7e149b.chunk.js +1 -0
  174. rapidfireai/frontend/build/static/js/6715.55a5c19c.chunk.js +1 -0
  175. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js +2 -0
  176. rapidfireai/frontend/build/static/js/6756.e6cb993c.chunk.js.LICENSE.txt +10 -0
  177. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js +2 -0
  178. rapidfireai/frontend/build/static/js/6762.acfde9fd.chunk.js.LICENSE.txt +19 -0
  179. rapidfireai/frontend/build/static/js/6846.67103d0e.chunk.js +1 -0
  180. rapidfireai/frontend/build/static/js/6861.34cf0198.chunk.js +1 -0
  181. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js +2 -0
  182. rapidfireai/frontend/build/static/js/6899.0eaf36a8.chunk.js.LICENSE.txt +5 -0
  183. rapidfireai/frontend/build/static/js/6933.8b564944.chunk.js +1 -0
  184. rapidfireai/frontend/build/static/js/699.d0437920.js +1 -0
  185. rapidfireai/frontend/build/static/js/7076.4182f63a.chunk.js +1 -0
  186. rapidfireai/frontend/build/static/js/7186.42ad86d5.chunk.js +1 -0
  187. rapidfireai/frontend/build/static/js/7248.a46635fd.js +1 -0
  188. rapidfireai/frontend/build/static/js/725.6b15a14a.chunk.js +1 -0
  189. rapidfireai/frontend/build/static/js/7266.3575539d.chunk.js +1 -0
  190. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js +2 -0
  191. rapidfireai/frontend/build/static/js/7270.0a1e84fc.chunk.js.LICENSE.txt +6 -0
  192. rapidfireai/frontend/build/static/js/7367.7120474f.chunk.js +1 -0
  193. rapidfireai/frontend/build/static/js/7436.8e226055.js +1 -0
  194. rapidfireai/frontend/build/static/js/7504.ef223844.chunk.js +1 -0
  195. rapidfireai/frontend/build/static/js/7603.ee049fe3.chunk.js +1 -0
  196. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js +2 -0
  197. rapidfireai/frontend/build/static/js/7670.2835b49a.chunk.js.LICENSE.txt +6 -0
  198. rapidfireai/frontend/build/static/js/7721.7390b3cc.chunk.js +1 -0
  199. rapidfireai/frontend/build/static/js/7731.5796cced.chunk.js +1 -0
  200. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js +2 -0
  201. rapidfireai/frontend/build/static/js/775.660a5deb.chunk.js.LICENSE.txt +6 -0
  202. rapidfireai/frontend/build/static/js/7832.7976a3e4.chunk.js +1 -0
  203. rapidfireai/frontend/build/static/js/7844.72cc2e81.chunk.js +1 -0
  204. rapidfireai/frontend/build/static/js/7948.48eab032.js +1 -0
  205. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js +2 -0
  206. rapidfireai/frontend/build/static/js/7972.085079d4.chunk.js.LICENSE.txt +6 -0
  207. rapidfireai/frontend/build/static/js/8017.a9e7dc5a.chunk.js +1 -0
  208. rapidfireai/frontend/build/static/js/8023.75f1f3df.js +2 -0
  209. rapidfireai/frontend/build/static/js/8023.75f1f3df.js.LICENSE.txt +41 -0
  210. rapidfireai/frontend/build/static/js/8123.b69db974.js +1 -0
  211. rapidfireai/frontend/build/static/js/813.065a87e5.chunk.js +1 -0
  212. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js +2 -0
  213. rapidfireai/frontend/build/static/js/819.2056f122.chunk.js.LICENSE.txt +6 -0
  214. rapidfireai/frontend/build/static/js/8262.04bc17d1.chunk.js +1 -0
  215. rapidfireai/frontend/build/static/js/8300.75adcc4f.chunk.js +1 -0
  216. rapidfireai/frontend/build/static/js/8336.b1d3e764.chunk.js +1 -0
  217. rapidfireai/frontend/build/static/js/8365.26cf64ea.chunk.js +1 -0
  218. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js +2 -0
  219. rapidfireai/frontend/build/static/js/8398.8bca8e0e.chunk.js.LICENSE.txt +6 -0
  220. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js +2 -0
  221. rapidfireai/frontend/build/static/js/847.33ceed50.chunk.js.LICENSE.txt +6 -0
  222. rapidfireai/frontend/build/static/js/8486.8ec852a7.chunk.js +1 -0
  223. rapidfireai/frontend/build/static/js/8497.19378265.chunk.js +1 -0
  224. rapidfireai/frontend/build/static/js/8541.4c55c9f4.chunk.js +1 -0
  225. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js +2 -0
  226. rapidfireai/frontend/build/static/js/8690.e305a804.chunk.js.LICENSE.txt +6 -0
  227. rapidfireai/frontend/build/static/js/8712.a9445fe6.chunk.js +1 -0
  228. rapidfireai/frontend/build/static/js/8763.61761e08.js +1 -0
  229. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js +2 -0
  230. rapidfireai/frontend/build/static/js/8823.baf9bffd.chunk.js.LICENSE.txt +6 -0
  231. rapidfireai/frontend/build/static/js/8867.767462b7.chunk.js +1 -0
  232. rapidfireai/frontend/build/static/js/8953.c0f88dea.chunk.js +1 -0
  233. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js +2 -0
  234. rapidfireai/frontend/build/static/js/8960.357cb1eb.chunk.js.LICENSE.txt +6 -0
  235. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js +2 -0
  236. rapidfireai/frontend/build/static/js/9.f4492795.chunk.js.LICENSE.txt +12 -0
  237. rapidfireai/frontend/build/static/js/9079.88a8d2a3.js +1 -0
  238. rapidfireai/frontend/build/static/js/9082.37c40520.chunk.js +10 -0
  239. rapidfireai/frontend/build/static/js/9133.90ae330d.js +2 -0
  240. rapidfireai/frontend/build/static/js/9133.90ae330d.js.LICENSE.txt +8 -0
  241. rapidfireai/frontend/build/static/js/9151.1ac359d5.js +2 -0
  242. rapidfireai/frontend/build/static/js/9151.1ac359d5.js.LICENSE.txt +8 -0
  243. rapidfireai/frontend/build/static/js/9168.027bf2fd.chunk.js +1 -0
  244. rapidfireai/frontend/build/static/js/9194.9c5cc548.chunk.js +10 -0
  245. rapidfireai/frontend/build/static/js/9244.026f4aee.chunk.js +1 -0
  246. rapidfireai/frontend/build/static/js/936.2e02d037.js +2 -0
  247. rapidfireai/frontend/build/static/js/936.2e02d037.js.LICENSE.txt +6 -0
  248. rapidfireai/frontend/build/static/js/9369.7d1a0a1d.chunk.js +1 -0
  249. rapidfireai/frontend/build/static/js/9427.7c8442e7.chunk.js +1 -0
  250. rapidfireai/frontend/build/static/js/944.55948859.chunk.js +1 -0
  251. rapidfireai/frontend/build/static/js/9499.c53a82da.js +2 -0
  252. rapidfireai/frontend/build/static/js/9499.c53a82da.js.LICENSE.txt +62 -0
  253. rapidfireai/frontend/build/static/js/9531.3ce05781.chunk.js +1 -0
  254. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js +2 -0
  255. rapidfireai/frontend/build/static/js/9547.92fac952.chunk.js.LICENSE.txt +6 -0
  256. rapidfireai/frontend/build/static/js/9620.b6e973a7.chunk.js +1 -0
  257. rapidfireai/frontend/build/static/js/9645.6fddfa65.chunk.js +1 -0
  258. rapidfireai/frontend/build/static/js/9669.d38dda6d.js +1 -0
  259. rapidfireai/frontend/build/static/js/9682.41b6b807.chunk.js +1 -0
  260. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js +2 -0
  261. rapidfireai/frontend/build/static/js/9720.19d5ae76.chunk.js.LICENSE.txt +23 -0
  262. rapidfireai/frontend/build/static/js/9723.d3c7fe9e.js +1 -0
  263. rapidfireai/frontend/build/static/js/9780.02a27630.chunk.js +10 -0
  264. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js +2 -0
  265. rapidfireai/frontend/build/static/js/9808.d0ca9674.chunk.js.LICENSE.txt +6 -0
  266. rapidfireai/frontend/build/static/js/9815.b8db3c5d.js +1 -0
  267. rapidfireai/frontend/build/static/js/9886.2940b53a.chunk.js +1 -0
  268. rapidfireai/frontend/build/static/js/main~1f912138.fa9d03b1.js +1 -0
  269. rapidfireai/frontend/build/static/js/main~43dd7041.2e00860d.js +1 -0
  270. rapidfireai/frontend/build/static/js/main~84781932.68deffff.js +1 -0
  271. rapidfireai/frontend/build/static/media/404-overflow.fad9a31861b0afba6f921ebb8e769688.svg +32 -0
  272. rapidfireai/frontend/build/static/media/RapidFire_Square_Bug.27ceb48296314a4bc0d4.png +0 -0
  273. rapidfireai/frontend/build/static/media/chart-bar.0fd4a63680fba840a7b69fbf07969f79.svg +7 -0
  274. rapidfireai/frontend/build/static/media/chart-contour.0d4b306f2669f3ad25375568935e3ce3.svg +5 -0
  275. rapidfireai/frontend/build/static/media/chart-difference.16174216d6f3b7c24f40e3541fe0ca2c.svg +20 -0
  276. rapidfireai/frontend/build/static/media/chart-image.cc434c4dc50780966344e2385a15f8fe.svg +6 -0
  277. rapidfireai/frontend/build/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg +4 -0
  278. rapidfireai/frontend/build/static/media/chart-parallel.da7dedf539b2af4b654d377c679173e4.svg +7 -0
  279. rapidfireai/frontend/build/static/media/chart-scatter.69118d0023a6ff3973f7fa913834ac47.svg +9 -0
  280. rapidfireai/frontend/build/static/media/default-error.f246ddf367c6fbd67942e5a13382a7f1.svg +26 -0
  281. rapidfireai/frontend/build/static/media/fontawesome-webfont.1e59d2330b4c6deb84b3.ttf +0 -0
  282. rapidfireai/frontend/build/static/media/fontawesome-webfont.20fd1704ea223900efa9.woff2 +0 -0
  283. rapidfireai/frontend/build/static/media/fontawesome-webfont.8b43027f47b20503057d.eot +0 -0
  284. rapidfireai/frontend/build/static/media/fontawesome-webfont.c1e38fd9e0e74ba58f7a.svg +2671 -0
  285. rapidfireai/frontend/build/static/media/fontawesome-webfont.f691f37e57f04c152e23.woff +0 -0
  286. rapidfireai/frontend/build/static/media/icon-visible-fill.8d34cd35303828fdfc15154f5536e63b.svg +7 -0
  287. rapidfireai/frontend/build/static/media/no-experiments.0e4f4a114ef73e7d81c09474aba64b6c.svg +22 -0
  288. rapidfireai/frontend/build/static/media/parallel-chart-placeholder.234ef0c5b220ef2a5a6fa5bafff173f7.svg +16 -0
  289. rapidfireai/frontend/build/static/media/permission-denied-lock.16036747d57cd663d7df223781a447b2.svg +14 -0
  290. rapidfireai/frontend/build/static/media/promo-modal-content.e3b2c6c568ac192b9bec54b838b54850.svg +30 -0
  291. rapidfireai/frontend/build/static/media/registered-model-grey-ok.8274b58d39504c8d1b8c358aa1c9aa35.svg +23 -0
  292. rapidfireai/frontend/build/static/media/warning.290a3b14118933547965e91ea61c5a61.svg +3 -0
  293. rapidfireai/frontend/proxy_middleware.py +233 -0
  294. rapidfireai/frontend/server.py +25 -0
  295. rapidfireai/ml/__init__.py +0 -0
  296. rapidfireai/ml/callbacks.py +176 -0
  297. rapidfireai/ml/checkpoint_utils.py +540 -0
  298. rapidfireai/ml/trainer.py +309 -0
  299. rapidfireai/start.sh +634 -0
  300. rapidfireai/utils/__init__.py +0 -0
  301. rapidfireai/utils/automl_utils.py +51 -0
  302. rapidfireai/utils/constants.py +141 -0
  303. rapidfireai/utils/datapaths.py +69 -0
  304. rapidfireai/utils/exceptions.py +82 -0
  305. rapidfireai/utils/experiment_utils.py +370 -0
  306. rapidfireai/utils/logging.py +87 -0
  307. rapidfireai/utils/mlflow_manager.py +121 -0
  308. rapidfireai/utils/serialize.py +15 -0
  309. rapidfireai/utils/shm_manager.py +469 -0
  310. rapidfireai/utils/trainer_config.py +23 -0
  311. rapidfireai/utils/worker_manager.py +219 -0
  312. rapidfireai/version.py +6 -0
  313. rapidfireai-0.9.10.dist-info/METADATA +247 -0
  314. rapidfireai-0.9.10.dist-info/RECORD +318 -0
  315. rapidfireai-0.9.10.dist-info/entry_points.txt +2 -0
  316. rapidfireai-0.0.1.dist-info/METADATA +0 -37
  317. rapidfireai-0.0.1.dist-info/RECORD +0 -6
  318. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/WHEEL +0 -0
  319. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/licenses/LICENSE +0 -0
  320. {rapidfireai-0.0.1.dist-info → rapidfireai-0.9.10.dist-info}/top_level.txt +0 -0
rapidfireai/__init__.py CHANGED
@@ -1,11 +1,17 @@
1
1
  """
2
- RapidFire AI - Coming Soon
2
+ RapidFire AI
3
3
  """
4
4
 
5
- __version__ = "0.0.1"
6
- __author__ = "Pradyumna Sridhara"
7
- __email__ = "pradyumna@rapidfire.ai"
5
+ from .version import __version__, __version_info__
6
+
7
+ __author__ = "RapidFire AI Inc."
8
+ __email__ = "support@rapidfire.ai"
9
+
10
+ from rapidfireai.experiment import Experiment
11
+
8
12
 
9
def coming_soon():
    """Return the package's placeholder status message.

    Takes no arguments and has no side effects.
    """
    status_message = "RapidFire AI package is under development. Stay tuned!"
    return status_message
16
+
17
+ __all__ = ["Experiment"]
@@ -0,0 +1,20 @@
1
+ """AutoML module for hyperparameter optimization."""
2
+
3
+ from .base import AutoMLAlgorithm
4
+ from .datatypes import List, Range
5
+ from .grid_search import RFGridSearch
6
+ from .model_config import RFDPOConfig, RFGRPOConfig, RFLoraConfig, RFModelConfig, RFSFTConfig
7
+ from .random_search import RFRandomSearch
8
+
9
# Public API of ``rapidfireai.automl``: search-space datatypes, search
# algorithms, the abstract algorithm base, and the RF* config wrappers.
__all__ = [
    "List", "Range",
    "RFGridSearch", "RFRandomSearch",
    "AutoMLAlgorithm",
    "RFModelConfig", "RFLoraConfig", "RFSFTConfig", "RFDPOConfig", "RFGRPOConfig",
]
@@ -0,0 +1,48 @@
1
+ """Base classes and configurations for AutoML algorithms."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from rapidfireai.automl.datatypes import List
7
+ from rapidfireai.automl.model_config import RFModelConfig
8
+ from rapidfireai.utils.exceptions import AutoMLException
9
+
10
+
11
class AutoMLAlgorithm(ABC):
    """Abstract base for AutoML search strategies over ``RFModelConfig`` objects.

    Subclasses implement :meth:`get_runs` to turn the stored configs into
    concrete run dictionaries.
    """

    VALID_TRAINER_TYPES = {"SFT", "DPO", "GRPO"}

    def __init__(self, configs=None, create_model_fn=None, trainer_type: str = "SFT", num_runs: int = 1):
        """Store normalized configs, trainer type and run count, then validate them.

        Args:
            configs: a single ``RFModelConfig``, a plain ``list`` of them, a
                ``List`` wrapper, or a falsy value meaning "no configs".
            create_model_fn: accepted but neither stored nor used by this base class.
            trainer_type: one of ``VALID_TRAINER_TYPES``; matched case-insensitively.
            num_runs: stored unchanged as ``self.num_runs``.

        Raises:
            AutoMLException: wrapping any error raised while initializing.
        """
        try:
            self.configs = self._normalize_configs(configs)
            self.trainer_type = trainer_type.upper()
            self.num_runs = num_runs

            if self.trainer_type not in self.VALID_TRAINER_TYPES:
                raise AutoMLException(f"trainer_type must be one of {self.VALID_TRAINER_TYPES}")

            self._validate_configs()
        except Exception as err:
            cls_name = type(self).__name__
            raise AutoMLException(f"Error initializing {cls_name}: {err}") from err

    def _normalize_configs(self, configs):
        """Coerce *configs* into a plain list (the input list is returned as-is, not copied)."""
        if isinstance(configs, List):
            return configs.values
        if isinstance(configs, list):
            return configs
        if not configs:
            return []
        return [configs]

    def _validate_configs(self):
        """Raise ``AutoMLException`` if any stored config is not an ``RFModelConfig``."""
        for candidate in self.configs:
            if isinstance(candidate, RFModelConfig):
                continue
            raise AutoMLException(f"All configs must be RFModelConfig instances, got {type(candidate)}")

    @abstractmethod
    def get_runs(self, seed: int) -> list[dict[str, Any]]:
        """Return the run dictionaries produced by this search strategy.

        The base implementation only validates *seed*; subclasses override it.
        """
        seed_is_valid = isinstance(seed, int) and seed >= 0
        if not seed_is_valid:
            raise AutoMLException("seed must be a non-negative integer")
@@ -0,0 +1,42 @@
1
+ """Contains classes for representing hyperparameter data types."""
2
+
3
+ import random
4
+
5
+ # TODO: need to set seed for random module.
6
+ # TODO: List.sample() will not work for nested lists.
7
+ # TODO: add support for sampling methods like 'uniform' and 'loguniform'.
8
+
9
+
10
+ class Range:
11
+ """Represents a range of values for a hyperparameter."""
12
+
13
+ def __init__(self, start, end, dtype: str | None = None):
14
+ if dtype is None:
15
+ self.dtype = "int" if isinstance(start, int) and isinstance(end, int) else "float"
16
+ else:
17
+ if dtype not in ("int", "float"):
18
+ raise ValueError("dtype must be either 'int' or 'float'.")
19
+ self.dtype = dtype
20
+ if not (isinstance(start, (int, float)) and isinstance(end, (int, float))):
21
+ raise ValueError("start and end must be either int or float.")
22
+ self.start = start
23
+ self.end = end
24
+
25
+ def sample(self):
26
+ """Sample a value from the range [self.start, self.end]."""
27
+ if self.dtype == "int":
28
+ return random.randint(self.start, self.end)
29
+ return random.uniform(self.start, self.end)
30
+
31
+
32
class List:
    """A discrete set of candidate values for one hyperparameter."""

    def __init__(self, values):
        # Only a genuine Python ``list`` is accepted; tuples, sets, strings etc. are rejected.
        if isinstance(values, list):
            self.values = values
        else:
            raise ValueError("List expects a list of values.")

    def sample(self):
        """Return one element of ``self.values`` chosen uniformly at random."""
        candidates = self.values
        return random.choice(candidates)
@@ -0,0 +1,125 @@
1
+ """Grid search implementation for AutoML training configurations."""
2
+
3
+ from itertools import product
4
+ from typing import Any, Dict
5
+ from typing import List as ListType
6
+
7
+ from rapidfireai.automl.base import AutoMLAlgorithm
8
+ from rapidfireai.automl.datatypes import List
9
+ from rapidfireai.utils.exceptions import AutoMLException
10
+
11
+
12
def recursive_expand_gridsearch(item: Any):
    """Yield every concrete variant of *item* obtained by expanding ``List`` choices.

    - dict: each value is expanded recursively, and one new dict is yielded per
      element of the cartesian product of the per-key variants (key order kept).
    - ``List``: each candidate value is expanded recursively, in order.
    - anything else (including plain Python lists) is yielded unchanged.
    """
    if isinstance(item, dict):
        names = list(item)
        per_key_variants = [list(recursive_expand_gridsearch(item[name])) for name in names]
        for combo in product(*per_key_variants):
            yield dict(zip(names, combo))
        return

    if isinstance(item, List):
        for candidate in item.values:
            yield from recursive_expand_gridsearch(candidate)
        return

    yield item
24
+
25
+
26
class RFGridSearch(AutoMLAlgorithm):
    """Grid search algorithm that generates all hyperparameter combinations.

    Every ``List`` found in a config's PEFT params, training args, model kwargs
    and extra attributes (plus ref-model kwargs for DPO and reward functions
    for GRPO) is expanded. One run dictionary is emitted per element of the
    resulting cartesian product.
    """

    # Config attributes that get_runs handles explicitly; every other non-None
    # attribute in ``config.__dict__`` is forwarded as ``additional_kwargs``.
    # FIXME: hard-coded; ideally derived from RFModelConfig's declared fields.
    _EXCLUDED_ATTRS = frozenset(
        {
            "model_name",
            "tokenizer",
            "tokenizer_kwargs",
            "model_type",
            "model_kwargs",
            "peft_config",
            "training_args",
            "ref_model_name",
            "ref_model_type",
            "ref_model_kwargs",
            "reward_funcs",
        }
    )

    @staticmethod
    def _peft_candidates(peft_config: Any) -> ListType[Any]:
        """Normalize ``config.peft_config`` (None, ``List``, list, or single) into a list."""
        if peft_config is None:
            return [None]
        if isinstance(peft_config, List):
            return peft_config.values
        if isinstance(peft_config, list):
            return peft_config
        return [peft_config]

    @staticmethod
    def _expand_or_empty(value: Any) -> ListType[Any]:
        """Return all grid expansions of *value*, or ``[{}]`` when *value* is None."""
        if value is None:
            return [{}]
        return list(recursive_expand_gridsearch(value))

    def get_runs(self, seed: int) -> ListType[Dict[str, Any]]:
        """Generate all possible hyperparameter combinations for grid search.

        Args:
            seed: must be a non-negative int. Grid search is deterministic, so
                the value is otherwise unused.

        Returns:
            One dict per combination, with keys ``trainer_type``, ``training_args``,
            ``peft_params``, ``model_name``, ``tokenizer``, ``tokenizer_kwargs``,
            ``model_type``, ``model_kwargs`` and ``additional_kwargs``. DPO runs
            also carry ``ref_model_config`` and GRPO runs carry ``reward_funcs``.

        Raises:
            AutoMLException: if *seed* is invalid or expansion fails.
        """
        if not isinstance(seed, int) or seed < 0:
            raise AutoMLException("seed must be a non-negative integer")

        try:
            runs = []
            for config in self.configs:
                for peft_config in self._peft_candidates(config.peft_config):
                    peft_instances = (
                        [{}] if peft_config is None else list(recursive_expand_gridsearch(peft_config._user_params))
                    )
                    training_instances = (
                        [{}]
                        if config.training_args is None
                        else list(recursive_expand_gridsearch(config.training_args._user_params))
                    )
                    model_kwargs_instances = self._expand_or_empty(config.model_kwargs)
                    ref_model_kwargs_instances = self._expand_or_empty(config.ref_model_kwargs)
                    reward_funcs_instances = self._expand_or_empty(config.reward_funcs)

                    # Extra Trainer kwargs: any remaining non-None config attribute.
                    extra_kwargs = {
                        k: v
                        for k, v in config.__dict__.items()
                        if k not in self._EXCLUDED_ATTRS and v is not None
                    }
                    extra_kwargs_instances = (
                        [{}] if not extra_kwargs else list(recursive_expand_gridsearch(extra_kwargs))
                    )

                    # product() iterates the last iterable fastest, matching the
                    # original peft -> training -> model_kwargs -> extra loop nest.
                    for peft_params, training_params, model_kwargs, extra in product(
                        peft_instances, training_instances, model_kwargs_instances, extra_kwargs_instances
                    ):
                        leaf = {
                            "trainer_type": self.trainer_type,
                            "training_args": training_params,
                            "peft_params": peft_params,
                            "model_name": config.model_name,
                            "tokenizer": config.tokenizer,
                            "tokenizer_kwargs": config.tokenizer_kwargs,
                            "model_type": config.model_type,
                            "model_kwargs": model_kwargs,
                            "additional_kwargs": extra,
                        }

                        if self.trainer_type == "DPO":
                            # Build a NEW dict per ref-model variant. Appending one shared
                            # leaf and mutating it made every DPO run alias the last variant.
                            for ref_model_kwargs in ref_model_kwargs_instances:
                                runs.append(
                                    {
                                        **leaf,
                                        "ref_model_config": {
                                            "model_name": config.ref_model_name,
                                            "model_type": config.ref_model_type,
                                            "model_kwargs": ref_model_kwargs,
                                        },
                                    }
                                )
                        elif self.trainer_type == "GRPO":
                            # Same aliasing fix as DPO: one fresh dict per reward_funcs variant.
                            for reward_funcs in reward_funcs_instances:
                                runs.append({**leaf, "reward_funcs": reward_funcs})
                        else:
                            runs.append(leaf)

            return runs

        except Exception as e:
            raise AutoMLException(f"Error generating runs: {e}") from e
@@ -0,0 +1,102 @@
1
+ """Model configuration for AutoML training."""
2
+
3
+ import inspect
4
+ import copy
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Optional, Type, Union, get_type_hints
7
+
8
+ from peft import LoraConfig
9
+ from trl import DPOConfig, GRPOConfig, SFTConfig
10
+
11
+ from rapidfireai.automl.datatypes import List, Range
12
+
13
+ def _create_rf_class(base_class: Type, class_name: str):
14
+ """Creating a RF class that dynamically inherits all constructor parameters and supports singleton, list, and Range values."""
15
+ if not inspect.isclass(base_class):
16
+ raise ValueError(f"base_class must be a class, got {type(base_class)}")
17
+
18
+ sig = inspect.signature(base_class.__init__)
19
+ constructor_params = [p for p in sig.parameters.keys() if p != "self"]
20
+
21
+ type_hints = get_type_hints(base_class)
22
+ new_type_hints = {}
23
+
24
+ for param_name, param_type in type_hints.items():
25
+ if param_name in constructor_params:
26
+ new_type_hints[param_name] = param_type | List | Range
27
+
28
+ def __init__(self, **kwargs):
29
+ self._user_params = copy.deepcopy(kwargs)
30
+ self._constructor_params = constructor_params
31
+ self._initializing = True
32
+
33
+ parent_kwargs = {}
34
+ for key, value in kwargs.items():
35
+ if not isinstance(value, (List, Range)):
36
+ parent_kwargs[key] = value
37
+
38
+ base_class.__init__(self, **parent_kwargs)
39
+
40
+ self._initializing = False
41
+ def copy_config(self):
42
+ """Create a deep copy of the configuration."""
43
+ copied_params = copy.deepcopy(self._user_params)
44
+ new_instance = self.__class__(**copied_params)
45
+
46
+ return new_instance
47
+
48
+ def __setattr__(self, name, value):
49
+ """Override setattr to update _user_params when constructor parameters are modified."""
50
+
51
+ if (hasattr(self, '_constructor_params') and
52
+ name in self._constructor_params and
53
+ hasattr(self, '_user_params') and
54
+ name in self._user_params and
55
+ not getattr(self, '_initializing', True)): # Don't update during init
56
+ self._user_params[name] = value
57
+
58
+ base_class.__setattr__(self, name, value)
59
+
60
+
61
+ return type(
62
+ class_name,
63
+ (base_class,),
64
+ {
65
+ "__doc__": f"RF version of {base_class.__name__}",
66
+ "__annotations__": new_type_hints,
67
+ "__init__": __init__,
68
+ "copy": copy_config,
69
+ "__setattr__": __setattr__
70
+ },
71
+ )
72
+
73
+
74
# RF wrapper classes around the peft / trl configuration classes. Each one accepts
# List / Range values as keyword arguments in addition to ordinary values.
RFLoraConfig = _create_rf_class(base_class=LoraConfig, class_name="RFLoraConfig")
RFSFTConfig = _create_rf_class(base_class=SFTConfig, class_name="RFSFTConfig")
RFDPOConfig = _create_rf_class(base_class=DPOConfig, class_name="RFDPOConfig")
RFGRPOConfig = _create_rf_class(base_class=GRPOConfig, class_name="RFGRPOConfig")
79
+
80
+
81
@dataclass
class RFModelConfig:
    """Model configuration for AutoML training.

    One instance describes the model, tokenizer, PEFT and trainer settings for a
    family of runs. Fields that accept ``List`` (and the RF config classes, whose
    constructors accept ``List``/``Range`` values) may hold several candidates
    that the AutoML search algorithms expand into concrete runs.
    """

    # Model to load. The default is None, so the annotation must be Optional.
    model_name: Optional[str] = None
    # Tokenizer name and extra keyword arguments for it.
    tokenizer: Optional[str] = None
    tokenizer_kwargs: Optional[dict[str, Any]] = None
    # Not consumed by name in the search algorithms: extra non-None fields such as
    # these are forwarded to the run as "additional_kwargs".
    formatting_func: Optional[Union[Callable, List]] = None
    compute_metrics: Optional[Union[Callable, List]] = None
    # A single RFLoraConfig, or a List of candidates to choose between.
    peft_config: Optional[Union[RFLoraConfig, List]] = None
    # Trainer arguments; presumably the class should match the trainer type
    # (SFT / DPO / GRPO) - confirm against the caller.
    training_args: Optional[Union[RFSFTConfig, RFDPOConfig, RFGRPOConfig]] = None
    model_type: Optional[str] = "causal_lm"
    model_kwargs: Optional[dict[str, Any]] = None
    # Reference model settings, used when building DPO runs.
    ref_model_name: Optional[str] = None
    ref_model_type: Optional[str] = None
    ref_model_kwargs: Optional[dict[str, Any]] = None
    # Reward function(s), used when building GRPO runs.
    reward_funcs: Optional[Union[str, List, Callable, Any]] = None
    # Also forwarded as an additional kwarg (it is not consumed by name).
    generation_config: Optional[dict[str, Any]] = None

    def copy(self):  # FIXME: Handle similar to create_rf_class
        """Return a deep copy of this RFModelConfig.

        Nested dicts, lists and RF config objects are copied recursively with
        ``copy.deepcopy``; plain functions are shared rather than duplicated, per
        the ``copy`` module's semantics. Inside this method ``copy`` refers to the
        module, not to the method itself.
        """
        return copy.deepcopy(self)
@@ -0,0 +1,145 @@
1
+ """Random search implementation for AutoML hyperparameter optimization."""
2
+
3
+ import random
4
+ import json
5
+ from itertools import product
6
+ from typing import Any, Dict
7
+ from typing import List as ListType
8
+
9
+ from rapidfireai.automl.base import AutoMLAlgorithm
10
+ from rapidfireai.automl.datatypes import List, Range
11
+ from rapidfireai.utils.exceptions import AutoMLException
12
+ from rapidfireai.utils.serialize import encode_payload
13
+
14
+
15
def recursive_expand_randomsearch(item: Any):
    """Resolve every search-space placeholder in *item* to one concrete value.

    Dicts are walked recursively (keys kept, values resolved). ``List`` and
    ``Range`` objects are replaced by a single ``sample()`` draw. Anything else is
    returned unchanged. Values produced by ``sample()`` are not expanded further.
    """
    if isinstance(item, dict):
        return {key: recursive_expand_randomsearch(value) for key, value in item.items()}
    if isinstance(item, (List, Range)):
        return item.sample()
    return item
24
+
25
class RFRandomSearch(AutoMLAlgorithm):
    """Random search algorithm that samples num_runs hyperparameter combinations.

    Each run is drawn independently: one model config is picked, then one PEFT
    config from it, then every ``List``/``Range`` inside the selected argument
    dictionaries is sampled. Duplicate runs are discarded.
    """

    # Config attributes consumed explicitly when building a run. Every other
    # non-None attribute of a config is forwarded as "additional_kwargs".
    # FIXME: avoid hardcoding the excluded attributes
    _EXCLUDED_ATTRS = frozenset(
        {
            "model_name",
            "tokenizer",
            "tokenizer_kwargs",
            "model_type",
            "model_kwargs",
            "peft_config",
            "training_args",
            "ref_model_name",
            "ref_model_type",
            "ref_model_kwargs",
            "reward_funcs",
        }
    )

    def get_runs(self, seed: int = 42) -> ListType[Dict[str, Any]]:
        """Generate ``num_runs`` unique random hyperparameter combinations.

        Args:
            seed: Seed passed to Python's global ``random`` module. ``None`` is
                allowed; any other value must be a non-negative int.

        Returns:
            A list of ``num_runs`` run dictionaries, each holding the trainer type,
            the sampled training/PEFT/model kwargs and, for DPO or GRPO, the
            reference-model config or reward functions.

        Raises:
            AutoMLException: If ``seed`` or ``num_runs`` is invalid, if sampling
                fails, or if ``num_runs`` unique runs cannot be found within
                ``10 * num_runs`` attempts.
        """
        if seed is not None and (not isinstance(seed, int) or seed < 0):
            raise AutoMLException("seed must be a non-negative integer")

        if not isinstance(self.num_runs, int) or self.num_runs <= 0:
            raise AutoMLException("num_runs must be a positive integer")

        # Seeds the process-wide RNG; presumably List/Range.sample() draw from it.
        # TODO confirm against rapidfireai.automl.datatypes.
        random.seed(seed)

        try:
            runs = []
            seen_configs = set()
            # Bound the number of draws so a small search space cannot loop forever.
            max_attempts = self.num_runs * 10
            attempts = 0

            while len(runs) < self.num_runs and attempts < max_attempts:
                attempts += 1
                leaf = self._sample_leaf(List(self.configs).sample())

                # Check for duplicates using a hashable (serialized) representation.
                config_hash = encode_payload(leaf)
                if config_hash not in seen_configs:
                    seen_configs.add(config_hash)
                    runs.append(leaf)

            if len(runs) < self.num_runs:
                raise AutoMLException(
                    f"Could not generate {self.num_runs} unique configurations. "
                    f"Generated {len(runs)} unique configs after {attempts} attempts. "
                )

            return runs

        except AutoMLException:
            # Already descriptive; re-wrapping would only prefix the message again.
            raise
        except Exception as e:
            raise AutoMLException(f"Error generating runs: {e}") from e

    @staticmethod
    def _select_peft_config(peft_config: Any) -> Any:
        """Pick one PEFT config from a plain list, an RF ``List``, or a single value."""
        if peft_config is None:
            return None
        if isinstance(peft_config, list):
            return List(peft_config).sample()
        if isinstance(peft_config, List):
            return peft_config.sample()
        return peft_config

    def _sample_leaf(self, config: Any) -> Dict[str, Any]:
        """Draw one concrete run dictionary from a single model config.

        The sampling order is fixed: PEFT, training args, model kwargs, ref-model
        kwargs, reward funcs, then additional kwargs. Reordering these draws would
        change which runs a given seed produces (presumably sample() uses the
        global RNG).
        """
        selected_peft_config = self._select_peft_config(config.peft_config)
        peft_params = (
            {} if selected_peft_config is None
            else recursive_expand_randomsearch(selected_peft_config._user_params)
        )

        training_params = (
            {} if config.training_args is None
            else recursive_expand_randomsearch(config.training_args._user_params)
        )

        model_kwargs = (
            {} if config.model_kwargs is None
            else recursive_expand_randomsearch(config.model_kwargs)
        )

        # Always sampled, even when the trainer type does not use it.
        ref_model_kwargs = (
            {} if config.ref_model_kwargs is None
            else recursive_expand_randomsearch(config.ref_model_kwargs)
        )

        # NOTE(review): an unset reward_funcs becomes {} rather than None;
        # confirm the GRPO trainer setup expects this.
        reward_funcs = (
            {} if config.reward_funcs is None
            else recursive_expand_randomsearch(config.reward_funcs)
        )

        additional_kwargs = {
            k: v
            for k, v in config.__dict__.items()
            if k not in self._EXCLUDED_ATTRS and v is not None
        }
        additional_kwargs_sampled = (
            {} if not additional_kwargs
            else recursive_expand_randomsearch(additional_kwargs)
        )

        leaf = {
            "trainer_type": self.trainer_type,
            "training_args": training_params,
            "peft_params": peft_params,
            "model_name": config.model_name,
            "tokenizer": config.tokenizer,
            "tokenizer_kwargs": config.tokenizer_kwargs,
            "model_type": config.model_type,
            "model_kwargs": model_kwargs,
            "additional_kwargs": additional_kwargs_sampled,
        }

        if self.trainer_type == "DPO":
            leaf["ref_model_config"] = {
                "model_name": config.ref_model_name,
                "model_type": config.ref_model_type,
                "model_kwargs": ref_model_kwargs,
            }
            # FIXME: correct ref args
        elif self.trainer_type == "GRPO":
            leaf["reward_funcs"] = reward_funcs

        return leaf
File without changes
@@ -0,0 +1,63 @@
1
+ """This module contains the DatasetChunks class, which splits a Hugging Face Dataset
2
+ into contiguous chunks for distributed processing."""
3
+
4
+ from datasets import Dataset
5
+
6
+
7
class DatasetChunks:
    """Chunks a HuggingFace Dataset into n_chunks for distributed processing.

    The dataset is split into contiguous, half-open index ranges whose sizes
    differ by at most one: the first ``len(dataset) % n_chunks`` chunks get one
    extra item. Chunks that would be empty (when ``n_chunks > len(dataset)``) are
    not created, so valid chunk ids are always ``0 .. len(chunk_ids) - 1``.
    """

    def __init__(self, dataset: Dataset, n_chunks: int):
        """Compute the chunk layout for ``dataset``.

        Args:
            dataset: The dataset to split; this class only calls ``len()`` and
                ``select()`` on it.
            n_chunks: Requested number of chunks; must be positive.

        Raises:
            ValueError: If ``n_chunks`` is not positive.
        """
        # Validate before touching the dataset.
        if n_chunks <= 0:
            raise ValueError(f"n_chunks must be positive, got {n_chunks}")

        self.dataset = dataset
        self.n_chunks = n_chunks
        self.dataset_size = len(dataset)

        # Every chunk gets base_size items; the first extra_items chunks get one more.
        self.base_size, self.extra_items = divmod(self.dataset_size, n_chunks)
        self.chunk_indices = self._create_chunk_indices()

    def _create_chunk_indices(self) -> dict[int, tuple[int, int]]:
        """Map each non-empty chunk id to its ``(start, end)`` range (end exclusive)."""
        chunks = {}
        current_idx = 0
        for chunk_id in range(self.n_chunks):
            chunk_size = self.base_size + (1 if chunk_id < self.extra_items else 0)
            if chunk_size == 0:
                # Sizes never grow with chunk_id, so every later chunk is empty too.
                break
            chunks[chunk_id] = (current_idx, current_idx + chunk_size)
            current_idx += chunk_size
        return chunks

    def _chunk_bounds(self, chunk_id: int) -> tuple[int, int]:
        """Return the ``(start, end)`` range of ``chunk_id`` or raise ``ValueError``."""
        if chunk_id in self.chunk_indices:
            return self.chunk_indices[chunk_id]
        if not self.chunk_indices:
            raise ValueError(f"Invalid chunk_id {chunk_id}. No chunks available: the dataset is empty")
        raise ValueError(f"Invalid chunk_id {chunk_id}. Valid range: 0-{len(self.chunk_indices) - 1}")

    def get_chunk(self, chunk_id: int) -> Dataset:
        """Get a chunk as a HuggingFace Dataset subset.

        Raises:
            ValueError: If ``chunk_id`` is not a valid chunk id.
        """
        start_idx, end_idx = self._chunk_bounds(chunk_id)
        # Dataset.select accepts a range directly, which avoids building (and then
        # scanning) a list holding every index of the chunk.
        return self.dataset.select(range(start_idx, end_idx))

    def get_chunk_size(self, chunk_id: int) -> int:
        """Get the number of items in a specific chunk.

        Raises:
            ValueError: If ``chunk_id`` is not a valid chunk id.
        """
        start_idx, end_idx = self._chunk_bounds(chunk_id)
        return end_idx - start_idx

    @property
    def chunk_ids(self):
        """Get all available (non-empty) chunk IDs in ascending order."""
        return list(self.chunk_indices.keys())