easy_ml 0.1.4 → 0.2.0.pre.rc1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (239) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,10 +1,549 @@
1
- require_relative "../model"
2
-
1
+ # == Schema Information
2
+ #
3
+ # Table name: easy_ml_models
4
+ #
5
+ # id :bigint not null, primary key
6
+ # name :string not null
7
+ # model_type :string
8
+ # status :string
9
+ # dataset_id :bigint
10
+ # configuration :json
11
+ # version :string not null
12
+ # root_dir :string
13
+ # file :json
14
+ # created_at :datetime not null
15
+ # updated_at :datetime not null
3
16
  module EasyML
4
17
  module Models
5
- class XGBoost < EasyML::Model
6
- include GlueGun::DSL
7
- include EasyML::Core::Models::XGBoostCore
18
+ class XGBoost < BaseModel
19
+ Hyperparameters = EasyML::Models::Hyperparameters::XGBoost
20
+
21
+ OBJECTIVES = {
22
+ classification: {
23
+ binary: %w[binary:logistic binary:hinge],
24
+ multiclass: %w[multi:softmax multi:softprob],
25
+ },
26
+ regression: %w[reg:squarederror reg:logistic],
27
+ }
28
+
29
+ OBJECTIVES_FRONTEND = {
30
+ classification: [
31
+ { value: "binary:logistic", label: "Binary Logistic", description: "For binary classification" },
32
+ { value: "binary:hinge", label: "Binary Hinge", description: "For binary classification with hinge loss" },
33
+ { value: "multi:softmax", label: "Multiclass Softmax", description: "For multiclass classification" },
34
+ { value: "multi:softprob", label: "Multiclass Probability",
35
+ description: "For multiclass classification with probability output" },
36
+ ],
37
+ regression: [
38
+ { value: "reg:squarederror", label: "Squared Error", description: "For regression with squared loss" },
39
+ { value: "reg:logistic", label: "Logistic", description: "For regression with logistic loss" },
40
+ ],
41
+ }
42
+
43
+ add_configuration_attributes :early_stopping_rounds
44
+ attr_accessor :xgboost_model, :booster
45
+
46
# Builds the booster-specific hyperparameters object for this model.
#
# The hash is symbolized and updated in place: a missing :booster defaults
# to :gbtree and :objective is always overwritten with the model's objective.
#
# @param params [Hash, nil] raw hyperparameter settings; nil is treated as {}
# @return [Object, nil] a GBTree, Dart or GBLinear hyperparameters instance,
#   or nil when +params+ is not a Hash
# @raise [RuntimeError] when the booster is not gbtree, dart or gblinear
def build_hyperparameters(params)
  params = {} if params.nil?
  return nil unless params.is_a?(Hash)

  params.to_h.symbolize_keys!
  params[:booster] = :gbtree unless params.key?(:booster)

  klass = case params[:booster].to_sym
          when :gbtree then Hyperparameters::GBTree
          when :dart then Hyperparameters::Dart
          when :gblinear then Hyperparameters::GBLinear
          else
            # Report the requested value; the `booster` accessor reads @booster
            # (the trained booster), not the booster type being configured.
            raise "Unknown booster type: #{params[:booster]}"
          end

  params.merge!(objective: model.objective)
  klass.new(params)
end
73
+
74
# Appends the automatically configured callbacks (wandb, evals, progress) to
# +params+ in place and moves the evals callback to the front.
#
# Nothing happens unless a Weights & Biases API key is configured.
# NOTE(review): the evals/progress callbacks are also only added when the key
# is present; this mirrors the existing behaviour - confirm it is intended.
#
# @param params [Array<Hash>] callback configurations; mutated in place
# @return [Array<Hash>, nil] +params+, or nil when no API key is configured
def add_auto_configurable_callbacks(params)
  return unless EasyML::Configuration.wandb_api_key.present?

  params.map!(&:deep_symbolize_keys)
  configured = ->(type) { params.any? { |c| c[:callback_type]&.to_sym == type } }

  unless configured.call(:wandb)
    params << {
      callback_type: :wandb,
      project_name: model.name,
      log_feature_importance: false,
      define_metric: false,
    }
  end
  params << { callback_type: :evals_callback } unless configured.call(:evals_callback)
  params << { callback_type: :progress_callback } unless configured.call(:progress_callback)

  # A stable partition keeps the relative order of the other callbacks, and
  # to_sym lets a string "evals_callback" be recognised like the checks above.
  evals, others = params.partition { |c| c[:callback_type]&.to_sym == :evals_callback }
  params.replace(evals + others)
end
101
+
102
# Instantiates callback objects from their configuration hashes.
#
# Each entry is either { callback_type: <type>, ...options } or the nested
# form { <type> => { ...options } }. Duplicate types are dropped (first wins).
#
# @param params [Array<Hash>] callback configurations; mutated in place
# @return [Array<Object>] callback instances, or [] when +params+ is not an Array
# @raise [RuntimeError] when a callback type is not wandb, evals_callback or
#   progress_callback
def build_callbacks(params)
  return [] unless params.is_a?(Array)

  add_auto_configurable_callbacks(params)

  # De-duplicate on the resolved type. Keys may still be strings here, and the
  # nested form has no callback_type key at all, so reading only
  # conf[:callback_type] would key every such entry to nil and collapse them.
  resolve_type = lambda do |conf|
    type = conf[:callback_type] || conf["callback_type"] || conf.keys.first
    type&.to_sym
  end
  params.uniq!(&resolve_type)

  params.map do |conf|
    conf.symbolize_keys!
    if conf.key?(:callback_type)
      callback_type = conf[:callback_type]
    else
      callback_type = conf.keys.first.to_sym
      conf = conf.values.first.symbolize_keys!
    end

    klass = case callback_type.to_sym
            when :wandb then Wandb::XGBoostCallback
            when :evals_callback then EasyML::Models::XGBoost::EvalsCallback
            when :progress_callback then EasyML::Models::XGBoost::ProgressCallback
            end
    raise "Unknown callback type #{callback_type}" unless klass

    klass.new(conf).tap do |instance|
      instance.instance_variable_set(:@callback_type, callback_type)
      instance.send(:model=, model) if instance.respond_to?(:model=)
    end
  end
end
131
+
132
# Gives every model callback that implements #after_tuning a chance to run it.
#
# @return [Object] whatever model.callbacks.each returns
def after_tuning
  model.callbacks.each do |cb|
    next unless cb.respond_to?(:after_tuning)

    cb.after_tuning
  end
end
137
+
138
# Points the wandb callback at the tuner's project, then lets each callback
# that implements #prepare_callback prepare itself for the tuning run.
#
# @param tuner [Object] must respond to #project_name; it is passed through
#   to each callback's #prepare_callback
def prepare_callbacks(tuner)
  set_wandb_project(tuner.project_name)

  model.callbacks.each do |cb|
    next unless cb.respond_to?(:prepare_callback)

    cb.prepare_callback(tuner)
  end
end
145
+
146
# Sets the project name on the model's wandb callback, if one is configured.
#
# @param project_name [String, nil]
# @return [String, nil] the assigned name, or nil when there is no wandb callback
def set_wandb_project(project_name)
  callback = find_wandb_callback
  return unless callback

  callback.project_name = project_name
end

# @return [String, nil] the wandb callback's current project name, or nil when
#   there is no wandb callback
def get_wandb_project
  callback = find_wandb_callback
  return nil unless callback

  callback.project_name
end

# Clears the project name on the wandb callback, if one is configured.
#
# @return [nil]
def delete_wandb_project
  callback = find_wandb_callback
  return nil unless callback

  callback.project_name = nil
end

# Finds the model callback whose class is exactly Wandb::XGBoostCallback.
# The constant is only resolved when there is at least one callback to test.
private def find_wandb_callback
  model.callbacks.find { |cb| cb.instance_of?(Wandb::XGBoostCallback) }
end
163
+
164
# Whether a booster is present and reports at least one feature name.
#
# @return [Boolean]
def is_fit?
  return false unless @booster.present?

  @booster.feature_names.any?
end
167
+
168
+ attr_accessor :progress_callback
169
+
170
# Trains the booster on the training split and stores it in @booster.
#
# When +x_train+ is nil the matrices come from #prepare_data. Otherwise the
# supplied frames are converted with #preprocess, so explicitly passed data
# is actually used instead of training on nil.
#
# @param tuning [Boolean] when false, a default wandb project name is set
#   before training and cleared afterwards
# @param x_train [Object, nil] training features; nil means "use the dataset"
# @param y_train [Object, nil] training labels for +x_train+
# @param x_valid [Object, nil] validation features; the "eval" set is skipped
#   when data is passed in without it
# @param y_valid [Object, nil] validation labels for +x_valid+
# @param progress_block [Proc] stored on #progress_callback
# @return [Object] the trained booster
def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
  validate_objective

  if x_train.nil?
    d_train, d_valid, = prepare_data
  else
    d_train = preprocess(x_train, y_train)
    d_valid = preprocess(x_valid, y_valid) unless x_valid.nil?
  end

  evals = [[d_train, "train"]]
  evals << [d_valid, "eval"] if d_valid

  self.progress_callback = progress_block
  set_default_wandb_project_name unless tuning
  @booster = base_model.train(hyperparameters.to_h,
                              d_train,
                              evals: evals,
                              num_boost_round: hyperparameters["n_estimators"],
                              callbacks: model.callbacks,
                              early_stopping_rounds: hyperparameters.to_h.dig("early_stopping_rounds"))
  delete_wandb_project unless tuning
  @booster
end
187
+
188
# Gives the wandb callback a "<model name>_<timestamp>" project name, unless a
# project name is already set.
#
# @return [String, nil] the result of #set_wandb_project, or nil when a
#   project name already exists
def set_default_wandb_project_name
  existing = get_wandb_project
  return if existing.present?

  timestamp = EasyML::Support::UTC.now.strftime("%Y_%m_%d_%H_%M_%S")
  set_wandb_project("#{model.name}_#{timestamp}")
end
195
+
196
# Trains the booster incrementally, one batch of the training split at a time.
#
# The n_estimators boosting rounds are divided evenly across the batches (the
# last batch takes the remainder). After each batch the booster is saved to
# "<checkpoint_dir>/<batch index>.json" and the next batch resumes from it.
#
# @param tuning [Boolean] when false, a default wandb project name is set
#   before training and cleared afterwards
# @param batch_size [Integer] forwarded to dataset.train
# @param batch_key [Object, nil] forwarded to dataset.train
# @param batch_start [Object, nil] forwarded to dataset.train
# @param batch_overlap [Integer] how many previous batches are concatenated
#   onto the current one before training; 0 disables overlap
# @param checkpoint_dir [Pathname, String] directory for per-batch checkpoints
# @return [Object] the trained booster
# @raise [ArgumentError] when the training split yields no batches
def fit_in_batches(tuning: false, batch_size: 1024, batch_key: nil, batch_start: nil, batch_overlap: 1, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"))
  validate_objective
  # Accept a plain String as well as a Pathname; #join is called on it below.
  checkpoint_dir = Pathname(checkpoint_dir)
  ensure_directory_exists(checkpoint_dir)
  set_default_wandb_project_name unless tuning

  # Prepare validation data
  x_valid, y_valid = dataset.valid(split_ys: true)
  d_valid = preprocess(x_valid, y_valid)

  num_iterations = hyperparameters.to_h[:n_estimators]
  early_stopping_rounds = hyperparameters.to_h[:early_stopping_rounds]

  num_batches = dataset.train(batch_size: batch_size, batch_start: batch_start, batch_key: batch_key).count
  # Fail with a clear message instead of a ZeroDivisionError below.
  raise ArgumentError, "cannot fit in batches: the training split produced no batches" if num_batches.zero?

  iterations_per_batch = num_iterations / num_batches
  stopping_points = (1..num_batches).map { |n| n * iterations_per_batch }
  # Integer division may leave a remainder; the last batch absorbs it.
  stopping_points[-1] = num_iterations

  current_iteration = 0
  current_batch = 0

  callbacks = model.callbacks.nil? ? [] : model.callbacks.dup
  callbacks << ::XGBoost::EvaluationMonitor.new(period: 1)

  # Generate batches without loading full dataset
  batches = dataset.train(split_ys: true, batch_size: batch_size, batch_start: batch_start, batch_key: batch_key)
  prev_xs = []
  prev_ys = []
  cb_container = nil

  while current_iteration < num_iterations
    # Load the next batch
    x_train, y_train = batches.next

    # Add batch_overlap from previous batch if applicable
    merged_x, merged_y = nil, nil
    if prev_xs.any?
      merged_x = Polars.concat([x_train] + prev_xs.flatten)
      merged_y = Polars.concat([y_train] + prev_ys.flatten)
    end

    if batch_overlap > 0
      prev_xs << [x_train]
      prev_ys << [y_train]
      # Keep only the most recent batch_overlap batches.
      if prev_xs.size > batch_overlap
        prev_xs = prev_xs[1..]
        prev_ys = prev_ys[1..]
      end
    end

    if merged_x.present?
      x_train = merged_x
      y_train = merged_y
    end

    d_train = preprocess(x_train, y_train)
    evals = [[d_train, "train"], [d_valid, "eval"]]

    # Every batch after the first resumes from the previous checkpoint.
    model_file = current_batch == 0 ? nil : checkpoint_dir.join("#{current_batch - 1}.json").to_s

    @booster = booster_class.new(
      params: hyperparameters.to_h.symbolize_keys,
      cache: [d_train, d_valid],
      model_file: model_file,
    )
    loop_callbacks = callbacks.dup
    loop_callbacks << ::XGBoost::EarlyStopping.new(rounds: early_stopping_rounds) if early_stopping_rounds
    cb_container = ::XGBoost::CallbackContainer.new(loop_callbacks)
    @booster = cb_container.before_training(@booster) if current_iteration == 0

    stopping_point = stopping_points[current_batch]
    while current_iteration < stopping_point
      break if cb_container.before_iteration(@booster, current_iteration, d_train, evals)
      @booster.update(d_train, current_iteration)
      break if cb_container.after_iteration(@booster, current_iteration, d_train, evals)
      current_iteration += 1
    end
    current_iteration = stopping_point # In case of early stopping

    @booster.save_model(checkpoint_dir.join("#{current_batch}.json").to_s)
    current_batch += 1
  end

  # cb_container is still nil when no boosting round was requested.
  @booster = cb_container.after_training(@booster) if cb_container
  delete_wandb_project unless tuning
  @booster
end
283
+
284
# Saves the booster to "tmp/xgboost_model.json" (a relative path) and returns
# the booster's dump.
#
# @return [Object] the result of @booster.get_dump
def weights
  trained = @booster
  trained.save_model("tmp/xgboost_model.json")
  trained.get_dump
end
288
+
289
# Runs the booster on +xs+ and, for classification tasks, converts the raw
# scores into class labels.
#
# @param xs [Object] features accepted by #preprocess
# @return [Array] raw predictions, or class labels when task is :classification
# @raise [RuntimeError] when no booster is loaded, when +xs+ is nil, or when the
#   booster complains that the number of columns does not match
def predict(xs)
  raise "No trained model! Train a model before calling predict" unless @booster.present?
  raise "Cannot predict on nil — XGBoost" if xs.nil?

  y_pred =
    begin
      @booster.predict(preprocess(xs))
    rescue StandardError => e
      # Only column-count mismatches are translated; anything else propagates.
      raise e unless e.message.match?(/Number of columns does not match/)

      raise %(
          >>>>><<<<<
          XGBoost received predict with unexpected features!
          >>>>><<<<<

          Model expects features:
            #{feature_names}
          Model received features:
            #{xs.columns}
        )
    end

  task.to_sym == :classification ? to_classification(y_pred) : y_pred
end
317
+
318
# Returns per-class scores for +data+.
#
# If the booster returns one score per row, each row becomes [1 - v, v]; if it
# already returns an array per row, that output is passed through unchanged.
#
# @param data [Object] anything ::XGBoost::DMatrix.new accepts
# @return [Array<Array<Numeric>>]
# @raise [RuntimeError] when no booster is loaded
def predict_proba(data)
  raise "No trained model! Train a model before calling predict" if @booster.nil?

  # Fully qualified: inside EasyML::Models::XGBoost a bare `DMatrix` does not
  # resolve to the gem's ::XGBoost::DMatrix.
  dmat = ::XGBoost::DMatrix.new(data)
  y_pred = @booster.predict(dmat)

  if y_pred.first.is_a?(Array)
    # multiple classes
    y_pred
  else
    y_pred.map { |v| [1 - v, v] }
  end
end
329
+
330
# Drops the references to the wrapped model and its booster.
#
# @return [nil]
def unload
  @xgboost_model = @booster = nil
end
334
+
335
# Whether a booster is present and reports at least one feature name.
#
# @return [Boolean]
def loaded?
  return false unless @booster.present?

  @booster.feature_names.any?
end
338
+
339
# Builds a booster from the model file at +path+, unless one is already loaded.
#
# @param path [String, nil] model file location; omitted from the booster
#   arguments when nil
# @return [Object, nil] nil when a model is already loaded, otherwise the
#   result of #initialize_model
def load_model_file(path)
  return if loaded?

  initialize_model do
    booster_params = hyperparameters.to_h.symbolize_keys.compact
    booster_args = { params: booster_params, model_file: path }.compact
    booster_class.new(**booster_args)
  end
end
350
+
351
# Reader that exposes the booster under the name "external model".
#
# @return [Object, nil] the current booster
def external_model
  booster
end

# Writer counterpart of #external_model.
#
# @param booster [Object, nil] the booster to store
def external_model=(booster)
  self.booster = booster
end
358
+
359
# Compares the SHA256 of the booster's saved JSON with a previous digest.
#
# The booster is written to a temporary file, which Tempfile.create removes
# when the block exits. The file is hashed directly rather than parsed first,
# because the parsed JSON was never used.
#
# @param prev_hash [String, nil] hex digest from an earlier save
# @return [Boolean] false when no fitted booster is present, otherwise whether
#   the current digest differs from +prev_hash+
def model_changed?(prev_hash)
  return false unless @booster.present? && @booster.feature_names.any?

  current_model_hash = Tempfile.create(["xgboost_model", ".json"]) do |tempfile|
    @booster.save_model(tempfile.path)
    Digest::SHA256.file(tempfile.path).hexdigest
  end
  current_model_hash != prev_hash
end
371
+
372
# Saves the booster as JSON, appending ".json" when +path+ does not already
# end with that extension, and makes sure the target directory exists first.
#
# @param path [String, Pathname] destination path
# @return [String] the path the model was written to
def save_model_file(path)
  target = path.to_s
  ensure_directory_exists(File.dirname(target))
  target = "#{target}.json" unless File.extname(target) == ".json"

  @booster.save_model(target)
  target
end
381
+
382
# @return [Array] the feature names reported by the booster
def feature_names
  booster.feature_names
end
385
+
386
# Normalised importance of every feature, as a share of the total score.
#
# Uses @importance_type when set, otherwise "gain". Features missing from the
# booster's score hash count as 0.0.
#
# @return [Hash] feature name => share of the total score; every value is 0.0
#   when the total is zero, rather than NaN from dividing by zero
def feature_importances
  names = @booster.feature_names
  score = @booster.score(importance_type: @importance_type || "gain")
  scores = names.map { |name| score[name] || 0.0 }
  total = scores.sum.to_f
  return names.to_h { |name| [name, 0.0] } if total.zero?

  names.zip(scores.map { |s| s / total }).to_h
end
393
+
394
# The top-level ::XGBoost constant (the leading :: skips this class, which is
# itself named XGBoost).
#
# @return [Module]
def base_model
  xgboost_library = ::XGBoost
  xgboost_library
end
397
+
398
# Builds the train/valid/test matrices and memoizes them. The matrices are
# rebuilt only while @d_train is nil.
#
# @return [Array] [@d_train, @d_valid, @d_test]
def prepare_data
  return [@d_train, @d_valid, @d_test] unless @d_train.nil?

  # Ensure we fail fast if the dataset is misconfigured
  sample_x, sample_y = dataset.train(split_ys: true, limit: 5)
  preprocess(sample_x, sample_y)

  train_x, train_y = dataset.train(split_ys: true)
  valid_x, valid_y = dataset.valid(split_ys: true)
  test_x, test_y = dataset.test(split_ys: true)
  @d_train = preprocess(train_x, train_y)
  @d_valid = preprocess(valid_x, valid_y)
  @d_test = preprocess(test_x, test_y)

  [@d_train, @d_valid, @d_test]
end
412
+
413
# Converts features (and optional labels) into an ::XGBoost::DMatrix whose
# feature names are the input's column names.
#
# @param xs [Object] features; an existing ::XGBoost::DMatrix is returned as is
# @param ys [Object, nil] labels; flattened into a single list when given
# @return [Object] the DMatrix
# @raise [RuntimeError] when the DMatrix cannot be built; the message lists
#   the Categorical and String columns of the input
def preprocess(xs, ys = nil)
  return xs if xs.is_a?(::XGBoost::DMatrix)

  original_frame = xs.dup
  feature_columns = xs.columns
  xs = _preprocess(xs)
  ys = _preprocess(ys).flatten unless ys.nil?
  label_options = { label: ys }.compact

  begin
    matrix = ::XGBoost::DMatrix.new(xs, **label_options)
    matrix.feature_names = feature_columns
    matrix
  rescue StandardError
    problematic_columns = original_frame.schema.select do |_, dtype|
      [Polars::Categorical, Polars::String].include?(dtype)
    end
    problematic_xs = original_frame.select(problematic_columns.keys)
    raise %(
          Error building data for XGBoost.
          Apply preprocessing to columns
          >>>>><<<<<
          #{problematic_columns.keys}
          >>>>><<<<<
          A sample of your dataset:
          #{problematic_xs[0..5]}

          #{if ys.present?
              %(
                This may also be due to your targets:
                #{ys[0..5]}
              )
            else
              ""
            end}
        )
  end
end
448
+
449
# Delegates to the XGBoost hyperparameters class, via this class's
# Hyperparameters alias for EasyML::Models::Hyperparameters::XGBoost.
#
# @return [Object] whatever that class reports as its hyperparameter constants
def self.hyperparameter_constants
  Hyperparameters.hyperparameter_constants
end
452
+
453
+ private
454
+
455
# The three methods below name the ::XGBoost classes this model uses.

# @return [Class] ::XGBoost::Booster
def booster_class
  ::XGBoost::Booster
end

# @return [Class] ::XGBoost::DMatrix
def d_matrix_class
  ::XGBoost::DMatrix
end

# @return [Class] ::XGBoost::Model
def model_class
  ::XGBoost::Model
end
466
+
467
# Runs one boosting iteration on +d_train+, creating the booster first when
# none exists yet.
#
# @param d_train [Object] training matrix
# @param current_iteration [Integer] iteration index passed to the booster
#   and to the callbacks
# @param evals [Array<Array>] [matrix, name] pairs; the matrices are added to
#   the booster cache
# @param cb_container [Object] callback container that is notified around the
#   iteration
def fit_batch(d_train, current_iteration, evals, cb_container)
  if @booster.nil?
    cached_matrices = [d_train] + evals.map { |pair| pair[0] }
    @booster = booster_class.new(
      params: @hyperparameters.to_h,
      cache: cached_matrices,
      early_stopping_rounds: @hyperparameters.to_h.dig(:early_stopping_rounds),
    )
  end

  @booster = cb_container.before_training(@booster)
  cb_container.before_iteration(@booster, current_iteration, d_train, evals)
  @booster.update(d_train, current_iteration)
  cb_container.after_iteration(@booster, current_iteration, d_train, evals)
end
479
+
480
# Turns a frame into an array of rows of plain values.
#
# Arrays are returned untouched. Otherwise each row (a Hash from df.to_a) is
# reduced to its values: Time and Date become Unix timestamps, booleans become
# 1.0 / 0.0, Strings and Integers are kept, and anything else goes through
# #to_f.
#
# @param df [Array, Object] an Array, or an object whose #to_a yields Hashes
# @return [Array]
def _preprocess(df)
  return df if df.is_a?(Array)

  df.to_a.map do |row|
    row.values.map do |cell|
      case cell
      when Time then cell.to_i
      when Date then cell.to_time.to_i
      when String, Integer then cell
      when true, false then (cell ? 1.0 : 0.0)
      else cell.to_f
      end
    end
  end
end
502
+
503
# Creates the wrapped model and a booster, and attaches the booster to it.
#
# The booster is whatever the given block returns; without a block a fresh
# booster is built from the symbolized, compacted hyperparameters.
#
# @yieldreturn [Object] the booster to use
def initialize_model
  @xgboost_model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
  @booster =
    if block_given?
      yield
    else
      booster_args = { params: hyperparameters.to_h.symbolize_keys }.deep_compact
      booster_class.new(**booster_args)
    end
  @xgboost_model.instance_variable_set(:@booster, @booster)
end
515
+
516
# Ensures the configured objective is allowed for the model's task.
#
# For classification the dataset's first target column decides the kind:
# at most two distinct values is binary, anything more is multiclass.
#
# @return [nil] when the objective is allowed
# @raise [ArgumentError] when no task is set, when the task has no entry in
#   OBJECTIVES, or when the objective is not allowed for the task
def validate_objective
  objective = hyperparameters.objective
  unless task.present?
    raise ArgumentError,
          "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
  end

  allowed_objectives =
    if task.to_sym == :classification
      _, ys = dataset.data(split_ys: true)
      # The key must match OBJECTIVES[:classification], which defines
      # :binary and :multiclass (there is no :multi_class entry).
      classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multiclass
      OBJECTIVES[:classification][classification_type]
    else
      OBJECTIVES[task.to_sym]
    end

  if allowed_objectives.nil?
    raise ArgumentError, "unknown task #{task}. Please specify either regression or classification"
  end
  return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)

  raise ArgumentError,
        "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
end
536
+
537
# Converts raw booster output into class labels.
#
# Rows that are arrays (one score per class) map to the index of their highest
# score; scalar rows map to 1 when above 0.5, otherwise 0.
#
# @param y_pred [Array] raw predictions
# @return [Array<Integer>]
def to_classification(y_pred)
  # multiple classes
  return y_pred.map { |scores| scores.each_with_index.max_by { |score, _| score }.last } if y_pred.first.is_a?(Array)

  y_pred.map { |probability| probability > 0.5 ? 1 : 0 }
end
8
547
  end
9
548
  end
10
549
  end
@@ -0,0 +1,44 @@
1
+ # == Schema Information
2
+ #
3
+ # Table name: easy_ml_predictions
4
+ #
5
+ # id :bigint not null, primary key
6
+ # model_id :bigint not null
7
+ # model_history_id :bigint
8
+ # prediction_type :string
9
+ # prediction_value :jsonb
10
+ # raw_input :jsonb
11
+ # normalized_input :jsonb
12
+ # created_at :datetime not null
13
+ # updated_at :datetime not null
14
+ #
15
module EasyML
  # One stored prediction made by a model: the raw and normalized inputs plus
  # the predicted value (and, where recorded, class probabilities).
  class Prediction < ActiveRecord::Base
    self.table_name = "easy_ml_predictions"

    belongs_to :model
    belongs_to :model_history, optional: true

    validates :model_id, presence: true
    validates :prediction_type, presence: true, inclusion: { in: %w[regression classification] }
    validates :prediction_value, presence: true
    validates :raw_input, presence: true
    validates :normalized_input, presence: true

    # @return [Object, nil] the "value" entry of prediction_value
    def prediction
      prediction_value["value"]
    end

    # @return [Object, nil] the "probabilities" entry of prediction_value
    def probabilities
      prediction_value["probabilities"]
    end

    # @return [Boolean] whether prediction_type is "regression"
    def regression?
      "regression" == prediction_type
    end

    # @return [Boolean] whether prediction_type is "classification"
    def classification?
      "classification" == prediction_type
    end
  end
end