easy_ml 0.1.4 → 0.2.0.pre.rc1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (239) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,181 +0,0 @@
1
- require "carrierwave"
2
- require_relative "uploaders/model_uploader"
3
-
4
- module EasyML
5
- module Core
6
- module ModelCore
7
- attr_accessor :dataset
8
-
9
- def self.included(base)
10
- base.send(:include, GlueGun::DSL)
11
- base.send(:extend, CarrierWave::Mount)
12
- base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)
13
-
14
- base.class_eval do
15
- validates :task, inclusion: { in: %w[regression classification] }
16
- validates :task, presence: true
17
- validate :dataset_is_a_dataset?
18
- validate :validate_any_metrics?
19
- validate :validate_metrics_for_task
20
- before_validation :save_model_file, if: -> { fit? }
21
- end
22
- end
23
-
24
- def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
25
- if x_train.nil?
26
- dataset.refresh!
27
- train_in_batches
28
- else
29
- train(x_train, y_train, x_valid, y_valid)
30
- end
31
- @is_fit = true
32
- end
33
-
34
- def predict(xs)
35
- raise NotImplementedError, "Subclasses must implement predict method"
36
- end
37
-
38
- def load
39
- raise NotImplementedError, "Subclasses must implement load method"
40
- end
41
-
42
- def _save_model_file
43
- raise NotImplementedError, "Subclasses must implement _save_model_file method"
44
- end
45
-
46
- def save
47
- super if defined?(super) && self.class.superclass.method_defined?(:save)
48
- save_model_file
49
- end
50
-
51
- def decode_labels(ys, col: nil)
52
- dataset.decode_labels(ys, col: col)
53
- end
54
-
55
- def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
56
- evaluator ||= self.evaluator
57
- EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
58
- evaluator: evaluator)
59
- end
60
-
61
- def save_model_file
62
- raise "No trained model! Need to train model before saving (call model.fit)" unless fit?
63
-
64
- path = File.join(model_dir, "#{version}.json")
65
- ensure_directory_exists(File.dirname(path))
66
-
67
- _save_model_file(path)
68
-
69
- File.open(path) do |f|
70
- self.file = f
71
- end
72
- file.store!
73
-
74
- cleanup
75
- end
76
-
77
- def get_params
78
- @hyperparameters.to_h
79
- end
80
-
81
- def allowed_metrics
82
- return [] unless task.present?
83
-
84
- case task.to_sym
85
- when :regression
86
- %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
87
- when :classification
88
- %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
89
- else
90
- []
91
- end
92
- end
93
-
94
- def cleanup!
95
- [carrierwave_dir, model_dir].each do |dir|
96
- EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
97
- end
98
- end
99
-
100
- def cleanup
101
- [carrierwave_dir, model_dir].each do |dir|
102
- EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
103
- end
104
- end
105
-
106
- def fit?
107
- @is_fit == true
108
- end
109
-
110
- private
111
-
112
- def carrierwave_dir
113
- return unless file.path.present?
114
-
115
- File.dirname(file.path).split("/")[0..-2].join("/")
116
- end
117
-
118
- def extension_allowlist
119
- EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
120
- end
121
-
122
- def _save_model_file(path = nil)
123
- raise NotImplementedError, "Subclasses must implement _save_model_file method"
124
- end
125
-
126
- def ensure_directory_exists(dir)
127
- FileUtils.mkdir_p(dir) unless File.directory?(dir)
128
- end
129
-
130
- def apply_defaults
131
- self.version ||= generate_version_string
132
- self.metrics ||= allowed_metrics
133
- self.ml_model ||= get_ml_model
134
- end
135
-
136
- def get_ml_model
137
- self.class.name.split("::").last.underscore
138
- end
139
-
140
- def generate_version_string
141
- timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
142
- model_name = self.class.name.split("::").last.underscore
143
- "#{model_name}_#{timestamp}"
144
- end
145
-
146
- def model_dir
147
- File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
148
- end
149
-
150
- def files_to_keep
151
- Dir.glob(File.join(carrierwave_dir, "**/*")).select { |f| File.file?(f) }.sort_by do |filename|
152
- Time.parse(filename.split("/").last.gsub(/\D/, ""))
153
- end.reverse.take(5)
154
- end
155
-
156
- def dataset_is_a_dataset?
157
- return if dataset.nil?
158
- return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
159
-
160
- errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
161
- end
162
-
163
- def validate_any_metrics?
164
- return if metrics.any?
165
-
166
- errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
167
- end
168
-
169
- def validate_metrics_for_task
170
- nonsensical_metrics = metrics.select do |metric|
171
- allowed_metrics.exclude?(metric)
172
- end
173
-
174
- return unless nonsensical_metrics.any?
175
-
176
- errors.add(:metrics,
177
- "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
178
- end
179
- end
180
- end
181
- end
@@ -1,34 +0,0 @@
1
- module EasyML
2
- module Models
3
- module Hyperparameters
4
- class Base
5
- include GlueGun::DSL
6
-
7
- attribute :learning_rate, :float, default: 0.01
8
- attribute :max_iterations, :integer, default: 100
9
- attribute :batch_size, :integer, default: 32
10
- attribute :regularization, :float, default: 0.0001
11
-
12
- def to_h
13
- attributes
14
- end
15
-
16
- def merge(other)
17
- return self if other.nil?
18
-
19
- other_hash = other.is_a?(Hyperparameters) ? other.to_h : other
20
- merged_hash = to_h.merge(other_hash)
21
- self.class.new(**merged_hash)
22
- end
23
-
24
- def [](key)
25
- send(key) if respond_to?(key)
26
- end
27
-
28
- def []=(key, value)
29
- send("#{key}=", value) if respond_to?("#{key}=")
30
- end
31
- end
32
- end
33
- end
34
- end
@@ -1,19 +0,0 @@
1
- module EasyML
2
- module Models
3
- module Hyperparameters
4
- class XGBoost < Base
5
- include GlueGun::DSL
6
-
7
- attribute :learning_rate, :float, default: 0.1
8
- attribute :max_depth, :integer, default: 6
9
- attribute :n_estimators, :integer, default: 100
10
- attribute :booster, :string, default: "gbtree"
11
- attribute :objective, :string, default: "reg:squarederror"
12
-
13
- validates :objective,
14
- inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
15
- reg:logistic] }
16
- end
17
- end
18
- end
19
- end
@@ -1,10 +0,0 @@
1
- require_relative "xgboost_core"
2
- module EasyML
3
- module Core
4
- module Models
5
- class XGBoost < EasyML::Core::Model
6
- include XGBoostCore
7
- end
8
- end
9
- end
10
- end
@@ -1,220 +0,0 @@
1
- require "wandb"
2
- module EasyML
3
- module Core
4
- module Models
5
- module XGBoostCore
6
- OBJECTIVES = {
7
- classification: {
8
- binary: %w[binary:logistic binary:hinge],
9
- multi_class: %w[multi:softmax multi:softprob]
10
- },
11
- regression: %w[reg:squarederror reg:logistic]
12
- }
13
-
14
- def self.included(base)
15
- base.class_eval do
16
- attribute :evaluator
17
-
18
- dependency :callbacks, { array: true } do |dep|
19
- dep.option :wandb do |opt|
20
- opt.set_class Wandb::XGBoostCallback
21
- opt.bind_attribute :log_model, default: false
22
- opt.bind_attribute :log_feature_importance, default: true
23
- opt.bind_attribute :importance_type, default: "gain"
24
- opt.bind_attribute :define_metric, default: true
25
- opt.bind_attribute :project_name
26
- end
27
- end
28
-
29
- dependency :hyperparameters do |dep|
30
- dep.set_class EasyML::Models::Hyperparameters::XGBoost
31
- dep.bind_attribute :batch_size, default: 32
32
- dep.bind_attribute :learning_rate, default: 1.1
33
- dep.bind_attribute :max_depth, default: 6
34
- dep.bind_attribute :n_estimators, default: 100
35
- dep.bind_attribute :booster, default: "gbtree"
36
- dep.bind_attribute :objective, default: "reg:squarederror"
37
- end
38
- end
39
- end
40
-
41
- attr_accessor :model, :booster
42
-
43
- def predict(xs)
44
- raise "No trained model! Train a model before calling predict" unless @booster.present?
45
- raise "Cannot predict on nil — XGBoost" if xs.nil?
46
-
47
- y_pred = @booster.predict(preprocess(xs))
48
-
49
- case task.to_sym
50
- when :classification
51
- to_classification(y_pred)
52
- else
53
- y_pred
54
- end
55
- end
56
-
57
- def predict_proba(data)
58
- dmat = DMatrix.new(data)
59
- y_pred = @booster.predict(dmat)
60
-
61
- if y_pred.first.is_a?(Array)
62
- # multiple classes
63
- y_pred
64
- else
65
- y_pred.map { |v| [1 - v, v] }
66
- end
67
- end
68
-
69
- def load(path = nil)
70
- path ||= file
71
- path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
72
-
73
- raise "No existing model at #{path}" unless File.exist?(path)
74
-
75
- initialize_model do
76
- booster_class.new(params: hyperparameters.to_h, model_file: path)
77
- end
78
- end
79
-
80
- def _save_model_file(path)
81
- puts "XGBoost received path #{path}"
82
- @booster.save_model(path)
83
- end
84
-
85
- def feature_importances
86
- @model.booster.feature_names.zip(@model.feature_importances).to_h
87
- end
88
-
89
- def base_model
90
- ::XGBoost
91
- end
92
-
93
- def customize_callbacks
94
- yield callbacks
95
- end
96
-
97
- private
98
-
99
- def booster_class
100
- ::XGBoost::Booster
101
- end
102
-
103
- def d_matrix_class
104
- ::XGBoost::DMatrix
105
- end
106
-
107
- def model_class
108
- ::XGBoost::Model
109
- end
110
-
111
- def train
112
- validate_objective
113
-
114
- xs = xs.to_a.map(&:values)
115
- ys = ys.to_a.map(&:values)
116
- dtrain = d_matrix_class.new(xs, label: ys)
117
- @model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
118
- end
119
-
120
- def train_in_batches
121
- validate_objective
122
-
123
- # Initialize the model with the first batch
124
- @model = nil
125
- @booster = nil
126
- x_valid, y_valid = dataset.valid(split_ys: true)
127
- x_train, y_train = dataset.train(split_ys: true)
128
- fit_batch(x_train, y_train, x_valid, y_valid)
129
- end
130
-
131
- def _preprocess(df)
132
- df.to_a.map do |row|
133
- row.values.map do |value|
134
- case value
135
- when Time
136
- value.to_i # Convert Time to Unix timestamp
137
- when Date
138
- value.to_time.to_i # Convert Date to Unix timestamp
139
- when String
140
- value
141
- when TrueClass, FalseClass
142
- value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
143
- when Integer
144
- value
145
- else
146
- value.to_f # Ensure everything else is converted to a float
147
- end
148
- end
149
- end
150
- end
151
-
152
- def preprocess(xs, ys = nil)
153
- column_names = xs.columns
154
- xs = _preprocess(xs)
155
- ys = ys.nil? ? nil : _preprocess(ys).flatten
156
- kwargs = { label: ys }.compact
157
- ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
158
- dmat.feature_names = column_names
159
- end
160
- end
161
-
162
- def initialize_model
163
- @model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
164
- @booster = yield
165
- @model.instance_variable_set(:@booster, @booster)
166
- end
167
-
168
- def validate_objective
169
- objective = hyperparameters.objective
170
- unless task.present?
171
- raise ArgumentError,
172
- "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
173
- end
174
-
175
- case task.to_sym
176
- when :classification
177
- _, ys = dataset.data(split_ys: true)
178
- classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
179
- allowed_objectives = OBJECTIVES[:classification][classification_type]
180
- else
181
- allowed_objectives = OBJECTIVES[task.to_sym]
182
- end
183
- return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
184
-
185
- raise ArgumentError,
186
- "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
187
- end
188
-
189
- def fit_batch(x_train, y_train, x_valid, y_valid)
190
- d_train = preprocess(x_train, y_train)
191
- d_valid = preprocess(x_valid, y_valid)
192
-
193
- evals = [[d_train, "train"], [d_valid, "eval"]]
194
-
195
- # # If this is the first batch, create the booster
196
- if @booster.nil?
197
- initialize_model do
198
- base_model.train(@hyperparameters.to_h, d_train,
199
- num_boost_round: @hyperparameters.to_h.dig("n_estimators"), evals: evals, callbacks: callbacks)
200
- end
201
- else
202
- # Update the existing booster with the new batch
203
- @model.update(d_train)
204
- end
205
- end
206
-
207
- def to_classification(y_pred)
208
- if y_pred.first.is_a?(Array)
209
- # multiple classes
210
- y_pred.map do |v|
211
- v.map.with_index.max_by { |v2, _| v2 }.last
212
- end
213
- else
214
- y_pred.map { |v| v > 0.5 ? 1 : 0 }
215
- end
216
- end
217
- end
218
- end
219
- end
220
- end
@@ -1,10 +0,0 @@
1
- module EasyML
2
- module Core
3
- module Models
4
- require_relative "models/hyperparameters"
5
- require_relative "models/xgboost"
6
-
7
- AVAILABLE_MODELS = [XGBoost]
8
- end
9
- end
10
- end
@@ -1,24 +0,0 @@
1
- require "carrierwave"
2
-
3
- module EasyML
4
- module Core
5
- module Uploaders
6
- class ModelUploader < CarrierWave::Uploader::Base
7
- # Choose storage type
8
- if Rails.env.production?
9
- storage :fog
10
- else
11
- storage :file
12
- end
13
-
14
- def store_dir
15
- "easy_ml_models/#{model.name}"
16
- end
17
-
18
- def extension_allowlist
19
- %w[bin model json]
20
- end
21
- end
22
- end
23
- end
24
- end
@@ -1,7 +0,0 @@
1
- module EasyML
2
- module Core
3
- module Uploaders
4
- require_relative "uploaders/model_uploader"
5
- end
6
- end
7
- end
@@ -1,6 +0,0 @@
1
- module ML
2
- module Data
3
- class Dataloader
4
- end
5
- end
6
- end
@@ -1,31 +0,0 @@
1
- {
2
- "annual_revenue": {
3
- "median": {
4
- "value": 3000.0,
5
- "original_dtype": {
6
- "__type__": "polars_dtype",
7
- "value": "Polars::Int64"
8
- }
9
- }
10
- },
11
- "loan_purpose": {
12
- "categorical": {
13
- "value": {
14
- "payroll": 4,
15
- "expansion": 1
16
- },
17
- "label_encoder": {
18
- "expansion": 0,
19
- "payroll": 1
20
- },
21
- "label_decoder": {
22
- "0": "expansion",
23
- "1": "payroll"
24
- },
25
- "original_dtype": {
26
- "__type__": "polars_dtype",
27
- "value": "Polars::String"
28
- }
29
- }
30
- }
31
- }
@@ -1 +0,0 @@
1
- {"previous_sample":1.0}
@@ -1 +0,0 @@
1
- {"previous_sample":1.0}
@@ -1,140 +0,0 @@
1
- require_relative "split"
2
-
3
- module EasyML
4
- module Data
5
- class Dataset
6
- module Splits
7
- class FileSplit < Split
8
- include GlueGun::DSL
9
- include EasyML::Data::Utils
10
-
11
- attribute :dir, :string
12
- attribute :polars_args, :hash, default: {}
13
- attribute :max_rows_per_file, :integer, default: 1_000_000
14
- attribute :batch_size, :integer, default: 10_000
15
- attribute :sample, :float, default: 1.0
16
- attribute :verbose, :boolean, default: false
17
-
18
- def initialize(options)
19
- super
20
- FileUtils.mkdir_p(dir)
21
- end
22
-
23
- def save(segment, df)
24
- segment_dir = File.join(dir, segment.to_s)
25
- FileUtils.mkdir_p(segment_dir)
26
-
27
- current_file = current_file_for_segment(segment)
28
- current_row_count = current_file && File.exist?(current_file) ? df(current_file).shape[0] : 0
29
- remaining_rows = max_rows_per_file - current_row_count
30
-
31
- while df.shape[0] > 0
32
- if df.shape[0] <= remaining_rows
33
- append_to_csv(df, current_file)
34
- break
35
- else
36
- df_to_append = df.slice(0, remaining_rows)
37
- df = df.slice(remaining_rows, df.shape[0] - remaining_rows)
38
- append_to_csv(df_to_append, current_file)
39
- current_file = new_file_path_for_segment(segment)
40
- remaining_rows = max_rows_per_file
41
- end
42
- end
43
- end
44
-
45
- def read(segment, split_ys: false, target: nil, drop_cols: [], &block)
46
- files = files_for_segment(segment)
47
-
48
- if block_given?
49
- result = nil
50
- total_rows = files.sum { |file| df(file).shape[0] }
51
- progress_bar = create_progress_bar(segment, total_rows) if verbose
52
-
53
- files.each do |file|
54
- df = self.df(file)
55
- df = sample_data(df) if sample < 1.0
56
- drop_cols &= df.columns
57
- df = df.drop(drop_cols) unless drop_cols.empty?
58
-
59
- if split_ys
60
- xs, ys = split_features_targets(df, true, target)
61
- result = process_block_with_split_ys(block, result, xs, ys)
62
- else
63
- result = process_block_without_split_ys(block, result, df)
64
- end
65
-
66
- progress_bar.progress += df.shape[0] if verbose
67
- end
68
- progress_bar.finish if verbose
69
- result
70
- elsif files.empty?
71
- return nil, nil if split_ys
72
-
73
- nil
74
-
75
- else
76
- combined_df = combine_dataframes(files)
77
- combined_df = sample_data(combined_df) if sample < 1.0
78
- drop_cols &= combined_df.columns
79
- combined_df = combined_df.drop(drop_cols) unless drop_cols.empty?
80
- split_features_targets(combined_df, split_ys, target)
81
- end
82
- end
83
-
84
- def cleanup
85
- FileUtils.rm_rf(dir)
86
- FileUtils.mkdir_p(dir)
87
- end
88
-
89
- def split_at
90
- return nil if output_files.empty?
91
-
92
- output_files.map { |file| File.mtime(file) }.max
93
- end
94
-
95
- private
96
-
97
- def read_csv_batched(path)
98
- Polars.read_csv_batched(path, batch_size: batch_size, **polars_args)
99
- end
100
-
101
- def df(path)
102
- Polars.read_csv(path, **polars_args)
103
- end
104
-
105
- def output_files
106
- Dir.glob("#{dir}/**/*.csv")
107
- end
108
-
109
- def files_for_segment(segment)
110
- segment_dir = File.join(dir, segment.to_s)
111
- Dir.glob(File.join(segment_dir, "**/*.csv")).sort
112
- end
113
-
114
- def current_file_for_segment(segment)
115
- current_file = files_for_segment(segment).last
116
- return new_file_path_for_segment(segment) if current_file.nil?
117
-
118
- row_count = df(current_file).shape[0]
119
- if row_count >= max_rows_per_file
120
- new_file_path_for_segment(segment)
121
- else
122
- current_file
123
- end
124
- end
125
-
126
- def new_file_path_for_segment(segment)
127
- segment_dir = File.join(dir, segment.to_s)
128
- file_number = Dir.glob(File.join(segment_dir, "*.csv")).count
129
- File.join(segment_dir, "#{segment}_%04d.csv" % file_number)
130
- end
131
-
132
- def combine_dataframes(files)
133
- dfs = files.map { |file| df(file) }
134
- Polars.concat(dfs)
135
- end
136
- end
137
- end
138
- end
139
- end
140
- end