easy_ml 0.1.4 → 0.2.0.pre.rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239)
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,10 +1,549 @@
1
- require_relative "../model"
2
-
1
# == Schema Information
#
# Table name: easy_ml_models
#
# id :bigint not null, primary key
# name :string not null
# model_type :string
# status :string
# dataset_id :bigint
# configuration :json
# version :string not null
# root_dir :string
# file :json
# created_at :datetime not null
# updated_at :datetime not null
module EasyML
  module Models
    # XGBoost implementation of an EasyML model.
    #
    # Holds a native ::XGBoost::Booster in +@booster+ and converts EasyML
    # dataset frames into ::XGBoost::DMatrix inputs (see #preprocess).
    # Hyperparameter objects come from EasyML::Models::Hyperparameters::XGBoost
    # (one class per booster type); training callbacks are assembled by
    # #build_callbacks.
    class XGBoost < BaseModel
      Hyperparameters = EasyML::Models::Hyperparameters::XGBoost

      # Objectives accepted per task; classification is further keyed by
      # label cardinality (see #validate_objective).
      OBJECTIVES = {
        classification: {
          binary: %w[binary:logistic binary:hinge],
          multiclass: %w[multi:softmax multi:softprob],
        },
        regression: %w[reg:squarederror reg:logistic],
      }.freeze

      # The same objectives, with labels and descriptions for the frontend.
      OBJECTIVES_FRONTEND = {
        classification: [
          { value: "binary:logistic", label: "Binary Logistic", description: "For binary classification" },
          { value: "binary:hinge", label: "Binary Hinge", description: "For binary classification with hinge loss" },
          { value: "multi:softmax", label: "Multiclass Softmax", description: "For multiclass classification" },
          { value: "multi:softprob", label: "Multiclass Probability",
            description: "For multiclass classification with probability output" },
        ],
        regression: [
          { value: "reg:squarederror", label: "Squared Error", description: "For regression with squared loss" },
          { value: "reg:logistic", label: "Logistic", description: "For regression with logistic loss" },
        ],
      }.freeze

      add_configuration_attributes :early_stopping_rounds
      attr_accessor :xgboost_model, :booster, :progress_callback

      # Builds the booster-specific hyperparameter object.
      #
      # @param params [Hash, nil] raw hyperparameters (String or Symbol keys);
      #   nil is treated as an empty hash. The caller's hash is not modified.
      # @return [Object, nil] a Hyperparameters::GBTree / Dart / GBLinear
      #   instance, or nil when +params+ is not a Hash.
      # @raise [RuntimeError] when +params[:booster]+ is not a known booster.
      def build_hyperparameters(params)
        params = {} if params.nil?
        return nil unless params.is_a?(Hash)

        # Work on a symbol-keyed copy rather than mutating the caller's hash.
        params = params.to_h.symbolize_keys
        params[:booster] = :gbtree if params[:booster].nil?

        klass = case params[:booster].to_sym
                when :gbtree then Hyperparameters::GBTree
                when :dart then Hyperparameters::Dart
                when :gblinear then Hyperparameters::GBLinear
                else
                  # Report the requested booster type (not the trained @booster).
                  raise "Unknown booster type: #{params[:booster]}"
                end

        # The model-level objective always overrides any objective in +params+.
        params[:objective] = model.objective
        klass.new(params)
      end

      # When a Weights & Biases API key is configured, makes sure the wandb,
      # evals and progress callbacks are present in +params+, with the evals
      # callback ordered first. Does nothing when no API key is configured.
      #
      # @param params [Array<Hash>] callback configurations; modified in place.
      def add_auto_configurable_callbacks(params)
        return unless EasyML::Configuration.wandb_api_key.present?

        params.map!(&:deep_symbolize_keys)
        configured = ->(type) { params.any? { |conf| callback_type_for(conf) == type } }

        unless configured.call(:wandb)
          params << {
            callback_type: :wandb,
            project_name: model.name,
            log_feature_importance: false,
            define_metric: false,
          }
        end
        params << { callback_type: :evals_callback } unless configured.call(:evals_callback)
        params << { callback_type: :progress_callback } unless configured.call(:progress_callback)

        # Move the evals callback to the front. partition keeps the relative order
        # of the remaining callbacks (sort_by! is not guaranteed to be stable).
        evals, others = params.partition { |conf| callback_type_for(conf) == :evals_callback }
        params.replace(evals + others)
      end

      # Instantiates training callbacks from their configurations.
      #
      # Each entry is either +{ callback_type: :wandb, ...options }+ or the
      # nested form +{ wandb: { ...options } }+. Keys are symbolized in place,
      # auto-configured callbacks are appended and duplicates (by type) are
      # dropped, keeping the first occurrence.
      #
      # @param params [Array<Hash>, Object] callback configurations.
      # @return [Array<Object>] callback instances; [] when +params+ is not an Array.
      # @raise [RuntimeError] for an unknown callback type.
      def build_callbacks(params)
        return [] unless params.is_a?(Array)

        # Symbolize first. Otherwise string-keyed configs all dedupe on nil in
        # uniq! and collapse into a single callback.
        params.map! { |conf| conf.to_h.deep_symbolize_keys }
        add_auto_configurable_callbacks(params)
        params.uniq! { |conf| callback_type_for(conf) }

        params.map do |conf|
          if conf.key?(:callback_type)
            callback_type = conf[:callback_type]
          else
            callback_type = conf.keys.first.to_sym
            conf = conf.values.first.to_h
          end

          klass = case callback_type.to_sym
                  when :wandb then Wandb::XGBoostCallback
                  when :evals_callback then EasyML::Models::XGBoost::EvalsCallback
                  when :progress_callback then EasyML::Models::XGBoost::ProgressCallback
                  end
          raise "Unknown callback type #{callback_type}" unless klass.present?

          klass.new(conf).tap do |instance|
            instance.instance_variable_set(:@callback_type, callback_type)
            instance.send(:model=, model) if instance.respond_to?(:model=)
          end
        end
      end

      # Calls +after_tuning+ on every model callback that implements it.
      def after_tuning
        configured_callbacks.each do |callback|
          callback.after_tuning if callback.respond_to?(:after_tuning)
        end
      end

      # Points the wandb callback (if any) at the tuner's project, then calls
      # +prepare_callback(tuner)+ on every callback that implements it.
      def prepare_callbacks(tuner)
        set_wandb_project(tuner.project_name)

        configured_callbacks.each do |callback|
          callback.prepare_callback(tuner) if callback.respond_to?(:prepare_callback)
        end
      end

      # Sets the project name on the wandb callback; no-op without one.
      def set_wandb_project(project_name)
        callback = wandb_callback
        return unless callback.present?

        callback.project_name = project_name
      end

      # @return [String, nil] the wandb callback's project name, or nil without one.
      def get_wandb_project
        wandb_callback&.project_name
      end

      # Clears the wandb callback's project name; returns nil.
      def delete_wandb_project
        callback = wandb_callback
        return nil unless callback.present?

        callback.project_name = nil
      end

      # Same check as #loaded?: a booster exists and has feature names.
      def is_fit?
        loaded?
      end

      # Trains the booster in one pass.
      #
      # Without +x_train+, the dataset's train/valid splits are used (see
      # #prepare_data). Otherwise the given frames are preprocessed;
      # +x_valid+/+y_valid+ are optional.
      #
      # @param tuning [Boolean] when true, the wandb project name is left untouched.
      # @return [::XGBoost::Booster] the trained booster (also stored in +@booster+).
      def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
        validate_objective

        d_train, d_valid = if x_train.nil?
                             prepare_data
                           else
                             [preprocess(x_train, y_train), x_valid.nil? ? nil : preprocess(x_valid, y_valid)]
                           end
        evals = [[d_train, "train"], [d_valid, "eval"]].reject { |dmat, _| dmat.nil? }
        self.progress_callback = progress_block
        hp = hyperparameter_values

        set_default_wandb_project_name unless tuning
        begin
          @booster = base_model.train(hyperparameters.to_h,
                                      d_train,
                                      evals: evals,
                                      num_boost_round: hp[:n_estimators],
                                      callbacks: model.callbacks,
                                      early_stopping_rounds: hp[:early_stopping_rounds])
        ensure
          # Clear the per-run project name even if training raised.
          delete_wandb_project unless tuning
        end
      end

      # Gives the wandb callback a timestamped project name unless it already has one.
      def set_default_wandb_project_name
        return if get_wandb_project.present?

        started_at = EasyML::Support::UTC.now
        project_name = "#{model.name}_#{started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
        set_wandb_project(project_name)
      end

      # Trains incrementally over training batches, checkpointing the booster
      # to +checkpoint_dir+ after each batch (see #train_in_batches).
      #
      # @param checkpoint_dir [Pathname, String] directory for per-batch checkpoints.
      # @return [::XGBoost::Booster] the trained booster.
      # @raise [ArgumentError] when n_estimators is unset or there are no training batches.
      def fit_in_batches(tuning: false, batch_size: 1024, batch_key: nil, batch_start: nil, batch_overlap: 1,
                         checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"))
        validate_objective
        # Accept String paths too; checkpoint file paths are built with Pathname#join.
        checkpoint_dir = Pathname.new(checkpoint_dir.to_s)
        ensure_directory_exists(checkpoint_dir)

        set_default_wandb_project_name unless tuning
        begin
          train_in_batches(batch_size: batch_size, batch_key: batch_key, batch_start: batch_start,
                           batch_overlap: batch_overlap, checkpoint_dir: checkpoint_dir)
        ensure
          delete_wandb_project unless tuning
        end
      end

      # @return the booster's text dump.
      # @raise [RuntimeError] when no model has been trained or loaded.
      def weights
        raise "No trained model! Train a model before calling weights" unless @booster.present?

        @booster.get_dump
      end

      # Predicts on +xs+ after #preprocess. Classification outputs become class
      # labels (see #to_classification); anything else is returned raw.
      #
      # @raise [RuntimeError] when untrained, when +xs+ is nil, or when the
      #   feature columns do not match the trained model.
      def predict(xs)
        raise "No trained model! Train a model before calling predict" unless @booster.present?
        raise "Cannot predict on nil — XGBoost" if xs.nil?

        begin
          y_pred = @booster.predict(preprocess(xs))
        rescue StandardError => e
          raise unless e.message.match?(/Number of columns does not match/)

          raise %(
          >>>>><<<<<
          XGBoost received predict with unexpected features!
          >>>>><<<<<

          Model expects features:
          #{feature_names}
          Model received features:
          #{xs.columns}
          )
        end

        task.to_sym == :classification ? to_classification(y_pred) : y_pred
      end

      # Class probabilities for +data+.
      #
      # Data frames and DMatrix inputs go through #preprocess; any other
      # input is handed to ::XGBoost::DMatrix directly.
      #
      # @return [Array<Array<Float>>] one row per sample. Binary outputs are
      #   expanded to [1 - p, p].
      def predict_proba(data)
        raise "No trained model! Train a model before calling predict" unless @booster.present?

        dmat = if data.is_a?(d_matrix_class) || data.respond_to?(:columns)
                 preprocess(data)
               else
                 d_matrix_class.new(data)
               end
        y_pred = @booster.predict(dmat)

        # Multiclass outputs are already one probability row per sample.
        return y_pred if y_pred.first.is_a?(Array)

        y_pred.map { |v| [1 - v, v] }
      end

      # Drops the in-memory model and booster.
      def unload
        @xgboost_model = nil
        @booster = nil
      end

      # True when a booster is present and reports at least one feature name.
      def loaded?
        @booster.present? && Array(@booster.feature_names).any?
      end

      # Loads a booster from +path+ unless one is already loaded.
      def load_model_file(path)
        return if loaded?

        initialize_model do
          attrs = {
            params: hyperparameter_values.compact,
            model_file: path,
          }.compact
          booster_class.new(**attrs)
        end
      end

      def external_model
        @booster
      end

      def external_model=(booster)
        @booster = booster
      end

      # Compares a SHA256 of the booster's saved JSON with +prev_hash+.
      # Returns false when no model is loaded.
      def model_changed?(prev_hash)
        return false unless loaded?

        current_model_hash = Tempfile.create(["xgboost_model", ".json"]) do |tempfile|
          @booster.save_model(tempfile.path)
          Digest::SHA256.file(tempfile.path).hexdigest
        end
        current_model_hash != prev_hash
      end

      # Saves the booster as JSON, appending ".json" to +path+ when it has a
      # different extension. Creates the parent directory if needed.
      #
      # @return [String] the path actually written.
      def save_model_file(path)
        raise "No trained model! Train a model before calling save_model_file" unless @booster.present?

        path = path.to_s
        ensure_directory_exists(File.dirname(path))
        path = "#{path}.json" unless File.extname(path) == ".json"

        @booster.save_model(path)
        path
      end

      def feature_names
        @booster.feature_names
      end

      # Per-feature importance normalized to sum to 1. Uses @importance_type,
      # defaulting to "gain". Features with no score get 0.0.
      #
      # @return [Hash{String => Float}]
      def feature_importances
        score = @booster.score(importance_type: @importance_type || "gain")
        names = @booster.feature_names
        scores = names.map { |name| score[name] || 0.0 }
        total = scores.sum.to_f
        # Avoid 0/0 (NaN) when no feature was ever used for a split.
        return names.to_h { |name| [name, 0.0] } if total.zero?

        names.zip(scores.map { |s| s / total }).to_h
      end

      def base_model
        ::XGBoost
      end

      # Memoized DMatrix triple for the dataset's train/valid/test splits.
      # A 5-row train sample is preprocessed first so misconfiguration fails fast.
      #
      # @return [Array(::XGBoost::DMatrix, ::XGBoost::DMatrix, ::XGBoost::DMatrix)]
      def prepare_data
        if @d_train.nil?
          x_sample, y_sample = dataset.train(split_ys: true, limit: 5)
          preprocess(x_sample, y_sample) # Ensure we fail fast if the dataset is misconfigured
          x_train, y_train = dataset.train(split_ys: true)
          x_valid, y_valid = dataset.valid(split_ys: true)
          x_test, y_test = dataset.test(split_ys: true)
          @d_train = preprocess(x_train, y_train)
          @d_valid = preprocess(x_valid, y_valid)
          @d_test = preprocess(x_test, y_test)
        end

        [@d_train, @d_valid, @d_test]
      end

      # Converts feature frame +xs+ (and optional targets +ys+) into a DMatrix
      # with feature names taken from +xs.columns+. DMatrix input is returned as-is.
      #
      # @raise [RuntimeError] with the String/Categorical columns, a data sample
      #   and the underlying error when DMatrix construction fails.
      def preprocess(xs, ys = nil)
        return xs if xs.is_a?(::XGBoost::DMatrix)

        column_names = xs.columns
        x_values = _preprocess(xs)
        y_values = ys.nil? ? nil : _preprocess(ys).flatten
        kwargs = { label: y_values }.compact
        begin
          ::XGBoost::DMatrix.new(x_values, **kwargs).tap do |dmat|
            dmat.feature_names = column_names
          end
        rescue StandardError => e
          problematic_columns = xs.schema.select { |_, dtype| [Polars::Categorical, Polars::String].include?(dtype) }
          problematic_xs = xs.select(problematic_columns.keys)
          target_hint = if y_values.present?
                          %(
          This may also be due to your targets:
          #{y_values[0..5]}
          )
                        else
                          ""
                        end
          raise %(
          Error building data for XGBoost.
          Apply preprocessing to columns
          >>>>><<<<<
          #{problematic_columns.keys}
          >>>>><<<<<
          A sample of your dataset:
          #{problematic_xs[0..5]}

          #{target_hint}
          Original error: #{e.message}
          )
        end
      end

      def self.hyperparameter_constants
        EasyML::Models::Hyperparameters::XGBoost.hyperparameter_constants
      end

      private

      def booster_class
        ::XGBoost::Booster
      end

      def d_matrix_class
        ::XGBoost::DMatrix
      end

      def model_class
        ::XGBoost::Model
      end

      # Hyperparameters as a Symbol-keyed Hash, whatever key type #to_h returns.
      def hyperparameter_values
        hyperparameters.to_h.symbolize_keys
      end

      # The model's callbacks, or [] when none are configured.
      def configured_callbacks
        model.callbacks || []
      end

      # The configured Wandb::XGBoostCallback instance, if any.
      def wandb_callback
        configured_callbacks.find { |cb| cb.instance_of?(Wandb::XGBoostCallback) }
      end

      # Callback type of a (symbol-keyed) config in either accepted form,
      # +{ callback_type: :x, ... }+ or +{ x: { ... } }+, as a Symbol.
      def callback_type_for(conf)
        type = conf.key?(:callback_type) ? conf[:callback_type] : conf.keys.first
        type&.to_sym
      end

      # Training loop behind #fit_in_batches.
      #
      # n_estimators is split evenly across the batches, and the last batch
      # takes any remainder. Each batch (merged with up to +batch_overlap+
      # previous batches) resumes from the previous batch's checkpoint.
      def train_in_batches(batch_size:, batch_key:, batch_start:, batch_overlap:, checkpoint_dir:)
        x_valid, y_valid = dataset.valid(split_ys: true)
        d_valid = preprocess(x_valid, y_valid)

        hp = hyperparameter_values
        num_iterations = hp[:n_estimators]
        early_stopping_rounds = hp[:early_stopping_rounds]
        raise ArgumentError, "n_estimators must be set to fit in batches" if num_iterations.nil?

        batch_args = { batch_size: batch_size, batch_start: batch_start, batch_key: batch_key }
        num_batches = dataset.train(**batch_args).count
        raise ArgumentError, "Cannot fit in batches: the training set produced no batches" if num_batches.zero?

        iterations_per_batch = num_iterations / num_batches
        stopping_points = (1..num_batches).map { |n| n * iterations_per_batch }
        stopping_points[-1] = num_iterations

        base_callbacks = configured_callbacks.dup
        base_callbacks << ::XGBoost::EvaluationMonitor.new(period: 1)

        # Batches are generated lazily rather than loading the full dataset.
        batches = dataset.train(split_ys: true, **batch_args)
        overlap_xs = []
        overlap_ys = []
        current_iteration = 0
        current_batch = 0
        cb_container = nil

        while current_iteration < num_iterations
          x_batch, y_batch = batches.next

          # Train on this batch plus the buffered previous batch(es).
          x_train = x_batch
          y_train = y_batch
          if overlap_xs.any?
            x_train = Polars.concat([x_batch] + overlap_xs)
            y_train = Polars.concat([y_batch] + overlap_ys)
          end

          if batch_overlap > 0
            overlap_xs << x_batch
            overlap_ys << y_batch
            if overlap_xs.size > batch_overlap
              overlap_xs.shift
              overlap_ys.shift
            end
          end

          d_train = preprocess(x_train, y_train)
          evals = [[d_train, "train"], [d_valid, "eval"]]

          # Resume from the previous batch's checkpoint.
          model_file = current_batch.zero? ? nil : checkpoint_dir.join("#{current_batch - 1}.json").to_s
          @booster = booster_class.new(params: hp, cache: [d_train, d_valid], model_file: model_file)

          loop_callbacks = base_callbacks.dup
          loop_callbacks << ::XGBoost::EarlyStopping.new(rounds: early_stopping_rounds) if early_stopping_rounds
          cb_container = ::XGBoost::CallbackContainer.new(loop_callbacks)
          @booster = cb_container.before_training(@booster) if current_iteration.zero?

          stopping_point = stopping_points[current_batch]
          while current_iteration < stopping_point
            break if cb_container.before_iteration(@booster, current_iteration, d_train, evals)

            @booster.update(d_train, current_iteration)
            break if cb_container.after_iteration(@booster, current_iteration, d_train, evals)

            current_iteration += 1
          end
          # NOTE(review): early stopping only ends the current batch; the next
          # batch still trains up to its own stopping point.
          current_iteration = stopping_point

          @booster.save_model(checkpoint_dir.join("#{current_batch}.json").to_s)
          current_batch += 1
        end

        # cb_container is nil when n_estimators is 0 (the loop never ran).
        @booster = cb_container.after_training(@booster) if cb_container
        @booster
      end

      # Runs a single boosting iteration on +d_train+, creating the booster on
      # first use. Booster.new only takes params/cache/model_file; early
      # stopping belongs in +cb_container+.
      def fit_batch(d_train, current_iteration, evals, cb_container)
        if @booster.nil?
          @booster = booster_class.new(params: hyperparameter_values, cache: [d_train] + evals.map(&:first))
        end

        @booster = cb_container.before_training(@booster)
        cb_container.before_iteration(@booster, current_iteration, d_train, evals)
        @booster.update(d_train, current_iteration)
        cb_container.after_iteration(@booster, current_iteration, d_train, evals)
      end

      # Converts a frame (anything whose #to_a yields row hashes) into an array
      # of row arrays for DMatrix. Arrays are returned unchanged.
      # Time/Date become Unix timestamps, booleans become 1.0/0.0, Strings and
      # Integers pass through, and everything else goes through #to_f.
      # NOTE(review): nil becomes 0.0 via #to_f, so missing values are not
      # passed to XGBoost as NaN; confirm upstream imputation covers this.
      def _preprocess(df)
        return df if df.is_a?(Array)

        df.to_a.map do |row|
          row.values.map do |value|
            case value
            when Time then value.to_i
            when Date then value.to_time.to_i
            when String then value
            when TrueClass, FalseClass then value ? 1.0 : 0.0
            when Integer then value
            else value.to_f
            end
          end
        end
      end

      # Builds the ::XGBoost::Model wrapper around a booster. The booster comes
      # from the block when given, otherwise a fresh one is built from the
      # hyperparameters.
      def initialize_model
        @xgboost_model = model_class.new(n_estimators: hyperparameter_values[:n_estimators])
        if block_given?
          @booster = yield
        else
          attrs = {
            params: hyperparameters.to_h.symbolize_keys,
          }.deep_compact
          @booster = booster_class.new(**attrs)
        end
        @xgboost_model.instance_variable_set(:@booster, @booster)
      end

      # Checks that the hyperparameter objective suits the task. For
      # classification, the label cardinality of the full dataset selects the
      # binary or multiclass objectives.
      #
      # @raise [ArgumentError] for a missing/unknown task or a disallowed objective.
      def validate_objective
        objective = hyperparameters.objective
        unless task.present?
          raise ArgumentError,
                "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
        end

        allowed_objectives =
          case task.to_sym
          when :classification
            _, ys = dataset.data(split_ys: true)
            # Must match the OBJECTIVES[:classification] keys (:binary / :multiclass).
            classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multiclass
            OBJECTIVES[:classification][classification_type]
          else
            OBJECTIVES[task.to_sym]
          end

        if allowed_objectives.nil?
          raise ArgumentError, "unknown task #{task}. Allowed tasks are: #{OBJECTIVES.keys.join(", ")}"
        end
        return if allowed_objectives.include?(objective.to_s)

        raise ArgumentError,
              "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
      end

      # Multiclass rows -> index of the highest score (first one on ties);
      # binary scores -> 1 when > 0.5, else 0.
      def to_classification(y_pred)
        if y_pred.first.is_a?(Array)
          y_pred.map { |scores| scores.each_with_index.max_by(&:first).last }
        else
          y_pred.map { |v| v > 0.5 ? 1 : 0 }
        end
      end
    end
  end
end
@@ -0,0 +1,44 @@
1
# == Schema Information
#
# Table name: easy_ml_predictions
#
# id :bigint not null, primary key
# model_id :bigint not null
# model_history_id :bigint
# prediction_type :string
# prediction_value :jsonb
# raw_input :jsonb
# normalized_input :jsonb
# created_at :datetime not null
# updated_at :datetime not null
#
module EasyML
  # One prediction made by an EasyML model. Stores the raw input, the
  # normalized input and the output (a "value" and optional "probabilities"
  # inside +prediction_value+).
  class Prediction < ActiveRecord::Base
    self.table_name = "easy_ml_predictions"

    belongs_to :model
    belongs_to :model_history, optional: true

    validates :model_id, presence: true
    validates :prediction_type, presence: true, inclusion: { in: %w[regression classification] }
    validates :prediction_value, presence: true
    validates :raw_input, presence: true
    validates :normalized_input, presence: true

    # @return [Object, nil] the "value" entry of +prediction_value+.
    def prediction
      prediction_field("value")
    end

    # @return [Object, nil] the "probabilities" entry of +prediction_value+.
    def probabilities
      prediction_field("probabilities")
    end

    def regression?
      prediction_type == "regression"
    end

    def classification?
      prediction_type == "classification"
    end

    private

    # Reads +key+ from +prediction_value+ by String key, falling back to the
    # Symbol key. Values reloaded from jsonb have String keys, but a Hash
    # assigned in memory may still carry Symbol keys.
    # Returns nil when +prediction_value+ is not a Hash (e.g. still nil).
    def prediction_field(key)
      return nil unless prediction_value.is_a?(Hash)

      prediction_value.fetch(key) { prediction_value[key.to_sym] }
    end
  end
end