easy_ml 0.1.4 → 0.2.0.pre.rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239)
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,181 +0,0 @@
- require "carrierwave"
- require_relative "uploaders/model_uploader"
-
- module EasyML
-   module Core
-     module ModelCore
-       attr_accessor :dataset
-
-       def self.included(base)
-         base.send(:include, GlueGun::DSL)
-         base.send(:extend, CarrierWave::Mount)
-         base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)
-
-         base.class_eval do
-           validates :task, inclusion: { in: %w[regression classification] }
-           validates :task, presence: true
-           validate :dataset_is_a_dataset?
-           validate :validate_any_metrics?
-           validate :validate_metrics_for_task
-           before_validation :save_model_file, if: -> { fit? }
-         end
-       end
-
-       def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
-         if x_train.nil?
-           dataset.refresh!
-           train_in_batches
-         else
-           train(x_train, y_train, x_valid, y_valid)
-         end
-         @is_fit = true
-       end
-
-       def predict(xs)
-         raise NotImplementedError, "Subclasses must implement predict method"
-       end
-
-       def load
-         raise NotImplementedError, "Subclasses must implement load method"
-       end
-
-       def _save_model_file
-         raise NotImplementedError, "Subclasses must implement _save_model_file method"
-       end
-
-       def save
-         super if defined?(super) && self.class.superclass.method_defined?(:save)
-         save_model_file
-       end
-
-       def decode_labels(ys, col: nil)
-         dataset.decode_labels(ys, col: col)
-       end
-
-       def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
-         evaluator ||= self.evaluator
-         EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
-                                               evaluator: evaluator)
-       end
-
-       def save_model_file
-         raise "No trained model! Need to train model before saving (call model.fit)" unless fit?
-
-         path = File.join(model_dir, "#{version}.json")
-         ensure_directory_exists(File.dirname(path))
-
-         _save_model_file(path)
-
-         File.open(path) do |f|
-           self.file = f
-         end
-         file.store!
-
-         cleanup
-       end
-
-       def get_params
-         @hyperparameters.to_h
-       end
-
-       def allowed_metrics
-         return [] unless task.present?
-
-         case task.to_sym
-         when :regression
-           %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
-         when :classification
-           %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
-         else
-           []
-         end
-       end
-
-       def cleanup!
-         [carrierwave_dir, model_dir].each do |dir|
-           EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
-         end
-       end
-
-       def cleanup
-         [carrierwave_dir, model_dir].each do |dir|
-           EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
-         end
-       end
-
-       def fit?
-         @is_fit == true
-       end
-
-       private
-
-       def carrierwave_dir
-         return unless file.path.present?
-
-         File.dirname(file.path).split("/")[0..-2].join("/")
-       end
-
-       def extension_allowlist
-         EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
-       end
-
-       def _save_model_file(path = nil)
-         raise NotImplementedError, "Subclasses must implement _save_model_file method"
-       end
-
-       def ensure_directory_exists(dir)
-         FileUtils.mkdir_p(dir) unless File.directory?(dir)
-       end
-
-       def apply_defaults
-         self.version ||= generate_version_string
-         self.metrics ||= allowed_metrics
-         self.ml_model ||= get_ml_model
-       end
-
-       def get_ml_model
-         self.class.name.split("::").last.underscore
-       end
-
-       def generate_version_string
-         timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
-         model_name = self.class.name.split("::").last.underscore
-         "#{model_name}_#{timestamp}"
-       end
-
-       def model_dir
-         File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
-       end
-
-       def files_to_keep
-         Dir.glob(File.join(carrierwave_dir, "**/*")).select { |f| File.file?(f) }.sort_by do |filename|
-           Time.parse(filename.split("/").last.gsub(/\D/, ""))
-         end.reverse.take(5)
-       end
-
-       def dataset_is_a_dataset?
-         return if dataset.nil?
-         return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
-
-         errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
-       end
-
-       def validate_any_metrics?
-         return if metrics.any?
-
-         errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
-       end
-
-       def validate_metrics_for_task
-         nonsensical_metrics = metrics.select do |metric|
-           allowed_metrics.exclude?(metric)
-         end
-
-         return unless nonsensical_metrics.any?
-
-         errors.add(:metrics,
-                    "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
-       end
-     end
-   end
- end
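
Note: the removed ModelCore mixin validated a model's metrics against a per-task allowlist. For reference, that lookup extracted as a standalone helper (a sketch of the removed 0.1.4 code, not the 0.2.0 API) reads:

# Standalone version of the removed allowed_metrics lookup: each task
# maps to the metric names the metric validations accepted.
def allowed_metrics(task)
  case task.to_sym
  when :regression
    %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
  when :classification
    %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
  else
    []
  end
end

allowed_metrics(:classification).include?("r2_score") # => false, so validate_metrics_for_task would reject it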
@@ -1,34 +0,0 @@
- module EasyML
-   module Models
-     module Hyperparameters
-       class Base
-         include GlueGun::DSL
-
-         attribute :learning_rate, :float, default: 0.01
-         attribute :max_iterations, :integer, default: 100
-         attribute :batch_size, :integer, default: 32
-         attribute :regularization, :float, default: 0.0001
-
-         def to_h
-           attributes
-         end
-
-         def merge(other)
-           return self if other.nil?
-
-           other_hash = other.is_a?(Hyperparameters) ? other.to_h : other
-           merged_hash = to_h.merge(other_hash)
-           self.class.new(**merged_hash)
-         end
-
-         def [](key)
-           send(key) if respond_to?(key)
-         end
-
-         def []=(key, value)
-           send("#{key}=", value) if respond_to?("#{key}=")
-         end
-       end
-     end
-   end
- end
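
Note: the removed Hyperparameters::Base gave hyperparameter objects hash-style access and a non-destructive merge. A usage sketch against the removed 0.1.4 class (values illustrative, assuming GlueGun-style keyword initialization):

params = EasyML::Models::Hyperparameters::Base.new(learning_rate: 0.05)
params[:batch_size]                            # => 32, the attribute default, read via []
params[:batch_size] = 64                       # written through []=
merged = params.merge({ max_iterations: 500 }) # returns a new instance; params is unchanged
merged.to_h                                    # plain hash of all attributes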
@@ -1,19 +0,0 @@
- module EasyML
-   module Models
-     module Hyperparameters
-       class XGBoost < Base
-         include GlueGun::DSL
-
-         attribute :learning_rate, :float, default: 0.1
-         attribute :max_depth, :integer, default: 6
-         attribute :n_estimators, :integer, default: 100
-         attribute :booster, :string, default: "gbtree"
-         attribute :objective, :string, default: "reg:squarederror"
-
-         validates :objective,
-                   inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
-                                       reg:logistic] }
-       end
-     end
-   end
- end
@@ -1,10 +0,0 @@
- require_relative "xgboost_core"
- module EasyML
-   module Core
-     module Models
-       class XGBoost < EasyML::Core::Model
-         include XGBoostCore
-       end
-     end
-   end
- end
@@ -1,220 +0,0 @@
- require "wandb"
- module EasyML
-   module Core
-     module Models
-       module XGBoostCore
-         OBJECTIVES = {
-           classification: {
-             binary: %w[binary:logistic binary:hinge],
-             multi_class: %w[multi:softmax multi:softprob]
-           },
-           regression: %w[reg:squarederror reg:logistic]
-         }
-
-         def self.included(base)
-           base.class_eval do
-             attribute :evaluator
-
-             dependency :callbacks, { array: true } do |dep|
-               dep.option :wandb do |opt|
-                 opt.set_class Wandb::XGBoostCallback
-                 opt.bind_attribute :log_model, default: false
-                 opt.bind_attribute :log_feature_importance, default: true
-                 opt.bind_attribute :importance_type, default: "gain"
-                 opt.bind_attribute :define_metric, default: true
-                 opt.bind_attribute :project_name
-               end
-             end
-
-             dependency :hyperparameters do |dep|
-               dep.set_class EasyML::Models::Hyperparameters::XGBoost
-               dep.bind_attribute :batch_size, default: 32
-               dep.bind_attribute :learning_rate, default: 1.1
-               dep.bind_attribute :max_depth, default: 6
-               dep.bind_attribute :n_estimators, default: 100
-               dep.bind_attribute :booster, default: "gbtree"
-               dep.bind_attribute :objective, default: "reg:squarederror"
-             end
-           end
-         end
-
-         attr_accessor :model, :booster
-
-         def predict(xs)
-           raise "No trained model! Train a model before calling predict" unless @booster.present?
-           raise "Cannot predict on nil — XGBoost" if xs.nil?
-
-           y_pred = @booster.predict(preprocess(xs))
-
-           case task.to_sym
-           when :classification
-             to_classification(y_pred)
-           else
-             y_pred
-           end
-         end
-
-         def predict_proba(data)
-           dmat = DMatrix.new(data)
-           y_pred = @booster.predict(dmat)
-
-           if y_pred.first.is_a?(Array)
-             # multiple classes
-             y_pred
-           else
-             y_pred.map { |v| [1 - v, v] }
-           end
-         end
-
-         def load(path = nil)
-           path ||= file
-           path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
-
-           raise "No existing model at #{path}" unless File.exist?(path)
-
-           initialize_model do
-             booster_class.new(params: hyperparameters.to_h, model_file: path)
-           end
-         end
-
-         def _save_model_file(path)
-           puts "XGBoost received path #{path}"
-           @booster.save_model(path)
-         end
-
-         def feature_importances
-           @model.booster.feature_names.zip(@model.feature_importances).to_h
-         end
-
-         def base_model
-           ::XGBoost
-         end
-
-         def customize_callbacks
-           yield callbacks
-         end
-
-         private
-
-         def booster_class
-           ::XGBoost::Booster
-         end
-
-         def d_matrix_class
-           ::XGBoost::DMatrix
-         end
-
-         def model_class
-           ::XGBoost::Model
-         end
-
-         def train
-           validate_objective
-
-           xs = xs.to_a.map(&:values)
-           ys = ys.to_a.map(&:values)
-           dtrain = d_matrix_class.new(xs, label: ys)
-           @model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
-         end
-
-         def train_in_batches
-           validate_objective
-
-           # Initialize the model with the first batch
-           @model = nil
-           @booster = nil
-           x_valid, y_valid = dataset.valid(split_ys: true)
-           x_train, y_train = dataset.train(split_ys: true)
-           fit_batch(x_train, y_train, x_valid, y_valid)
-         end
-
-         def _preprocess(df)
-           df.to_a.map do |row|
-             row.values.map do |value|
-               case value
-               when Time
-                 value.to_i # Convert Time to Unix timestamp
-               when Date
-                 value.to_time.to_i # Convert Date to Unix timestamp
-               when String
-                 value
-               when TrueClass, FalseClass
-                 value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
-               when Integer
-                 value
-               else
-                 value.to_f # Ensure everything else is converted to a float
-               end
-             end
-           end
-         end
-
-         def preprocess(xs, ys = nil)
-           column_names = xs.columns
-           xs = _preprocess(xs)
-           ys = ys.nil? ? nil : _preprocess(ys).flatten
-           kwargs = { label: ys }.compact
-           ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
-             dmat.feature_names = column_names
-           end
-         end
-
-         def initialize_model
-           @model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
-           @booster = yield
-           @model.instance_variable_set(:@booster, @booster)
-         end
-
-         def validate_objective
-           objective = hyperparameters.objective
-           unless task.present?
-             raise ArgumentError,
-                   "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
-           end
-
-           case task.to_sym
-           when :classification
-             _, ys = dataset.data(split_ys: true)
-             classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
-             allowed_objectives = OBJECTIVES[:classification][classification_type]
-           else
-             allowed_objectives = OBJECTIVES[task.to_sym]
-           end
-           return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
-
-           raise ArgumentError,
-                 "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
-         end
-
-         def fit_batch(x_train, y_train, x_valid, y_valid)
-           d_train = preprocess(x_train, y_train)
-           d_valid = preprocess(x_valid, y_valid)
-
-           evals = [[d_train, "train"], [d_valid, "eval"]]
-
-           # # If this is the first batch, create the booster
-           if @booster.nil?
-             initialize_model do
-               base_model.train(@hyperparameters.to_h, d_train,
-                                num_boost_round: @hyperparameters.to_h.dig("n_estimators"), evals: evals, callbacks: callbacks)
-             end
-           else
-             # Update the existing booster with the new batch
-             @model.update(d_train)
-           end
-         end
-
-         def to_classification(y_pred)
-           if y_pred.first.is_a?(Array)
-             # multiple classes
-             y_pred.map do |v|
-               v.map.with_index.max_by { |v2, _| v2 }.last
-             end
-           else
-             y_pred.map { |v| v > 0.5 ? 1 : 0 }
-           end
-         end
-       end
-     end
-   end
- end
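
Note: the removed validate_objective pairs the configured XGBoost objective with the task, narrowing classification objectives to binary or multi-class by the number of unique labels. That lookup in isolation (a standalone sketch; the helper name allowed_objectives_for is illustrative, not part of the gem):

OBJECTIVES = {
  classification: {
    binary: %w[binary:logistic binary:hinge],
    multi_class: %w[multi:softmax multi:softprob]
  },
  regression: %w[reg:squarederror reg:logistic]
}.freeze

def allowed_objectives_for(task, unique_labels = nil)
  if task == :classification
    kind = unique_labels <= 2 ? :binary : :multi_class
    OBJECTIVES[:classification][kind]
  else
    OBJECTIVES[task]
  end
end

allowed_objectives_for(:classification, 2) # => ["binary:logistic", "binary:hinge"]
allowed_objectives_for(:regression)        # => ["reg:squarederror", "reg:logistic"]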
@@ -1,10 +0,0 @@
- module EasyML
-   module Core
-     module Models
-       require_relative "models/hyperparameters"
-       require_relative "models/xgboost"
-
-       AVAILABLE_MODELS = [XGBoost]
-     end
-   end
- end
@@ -1,24 +0,0 @@
- require "carrierwave"
-
- module EasyML
-   module Core
-     module Uploaders
-       class ModelUploader < CarrierWave::Uploader::Base
-         # Choose storage type
-         if Rails.env.production?
-           storage :fog
-         else
-           storage :file
-         end
-
-         def store_dir
-           "easy_ml_models/#{model.name}"
-         end
-
-         def extension_allowlist
-           %w[bin model json]
-         end
-       end
-     end
-   end
- end
@@ -1,7 +0,0 @@
- module EasyML
-   module Core
-     module Uploaders
-       require_relative "uploaders/model_uploader"
-     end
-   end
- end
@@ -1,6 +0,0 @@
- module ML
-   module Data
-     class Dataloader
-     end
-   end
- end
@@ -1,31 +0,0 @@
- {
-   "annual_revenue": {
-     "median": {
-       "value": 3000.0,
-       "original_dtype": {
-         "__type__": "polars_dtype",
-         "value": "Polars::Int64"
-       }
-     }
-   },
-   "loan_purpose": {
-     "categorical": {
-       "value": {
-         "payroll": 4,
-         "expansion": 1
-       },
-       "label_encoder": {
-         "expansion": 0,
-         "payroll": 1
-       },
-       "label_decoder": {
-         "0": "expansion",
-         "1": "payroll"
-       },
-       "original_dtype": {
-         "__type__": "polars_dtype",
-         "value": "Polars::String"
-       }
-     }
-   }
- }
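
Note: the label_encoder/label_decoder maps in the removed statistics file are plain category-to-integer lookups. Applied by hand they behave like this (a sketch, not the preprocessor's API):

label_encoder = { "expansion" => 0, "payroll" => 1 }
label_decoder = label_encoder.invert

encoded = ["payroll", "expansion", "payroll"].map { |v| label_encoder[v] } # => [1, 0, 1]
decoded = encoded.map { |v| label_decoder[v] }                             # => ["payroll", "expansion", "payroll"]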
@@ -1 +0,0 @@
- {"previous_sample":1.0}
@@ -1 +0,0 @@
- {"previous_sample":1.0}
@@ -1,140 +0,0 @@
- require_relative "split"
-
- module EasyML
-   module Data
-     class Dataset
-       module Splits
-         class FileSplit < Split
-           include GlueGun::DSL
-           include EasyML::Data::Utils
-
-           attribute :dir, :string
-           attribute :polars_args, :hash, default: {}
-           attribute :max_rows_per_file, :integer, default: 1_000_000
-           attribute :batch_size, :integer, default: 10_000
-           attribute :sample, :float, default: 1.0
-           attribute :verbose, :boolean, default: false
-
-           def initialize(options)
-             super
-             FileUtils.mkdir_p(dir)
-           end
-
-           def save(segment, df)
-             segment_dir = File.join(dir, segment.to_s)
-             FileUtils.mkdir_p(segment_dir)
-
-             current_file = current_file_for_segment(segment)
-             current_row_count = current_file && File.exist?(current_file) ? df(current_file).shape[0] : 0
-             remaining_rows = max_rows_per_file - current_row_count
-
-             while df.shape[0] > 0
-               if df.shape[0] <= remaining_rows
-                 append_to_csv(df, current_file)
-                 break
-               else
-                 df_to_append = df.slice(0, remaining_rows)
-                 df = df.slice(remaining_rows, df.shape[0] - remaining_rows)
-                 append_to_csv(df_to_append, current_file)
-                 current_file = new_file_path_for_segment(segment)
-                 remaining_rows = max_rows_per_file
-               end
-             end
-           end
-
-           def read(segment, split_ys: false, target: nil, drop_cols: [], &block)
-             files = files_for_segment(segment)
-
-             if block_given?
-               result = nil
-               total_rows = files.sum { |file| df(file).shape[0] }
-               progress_bar = create_progress_bar(segment, total_rows) if verbose
-
-               files.each do |file|
-                 df = self.df(file)
-                 df = sample_data(df) if sample < 1.0
-                 drop_cols &= df.columns
-                 df = df.drop(drop_cols) unless drop_cols.empty?
-
-                 if split_ys
-                   xs, ys = split_features_targets(df, true, target)
-                   result = process_block_with_split_ys(block, result, xs, ys)
-                 else
-                   result = process_block_without_split_ys(block, result, df)
-                 end
-
-                 progress_bar.progress += df.shape[0] if verbose
-               end
-               progress_bar.finish if verbose
-               result
-             elsif files.empty?
-               return nil, nil if split_ys
-
-               nil
-
-             else
-               combined_df = combine_dataframes(files)
-               combined_df = sample_data(combined_df) if sample < 1.0
-               drop_cols &= combined_df.columns
-               combined_df = combined_df.drop(drop_cols) unless drop_cols.empty?
-               split_features_targets(combined_df, split_ys, target)
-             end
-           end
-
-           def cleanup
-             FileUtils.rm_rf(dir)
-             FileUtils.mkdir_p(dir)
-           end
-
-           def split_at
-             return nil if output_files.empty?
-
-             output_files.map { |file| File.mtime(file) }.max
-           end
-
-           private
-
-           def read_csv_batched(path)
-             Polars.read_csv_batched(path, batch_size: batch_size, **polars_args)
-           end
-
-           def df(path)
-             Polars.read_csv(path, **polars_args)
-           end
-
-           def output_files
-             Dir.glob("#{dir}/**/*.csv")
-           end
-
-           def files_for_segment(segment)
-             segment_dir = File.join(dir, segment.to_s)
-             Dir.glob(File.join(segment_dir, "**/*.csv")).sort
-           end
-
-           def current_file_for_segment(segment)
-             current_file = files_for_segment(segment).last
-             return new_file_path_for_segment(segment) if current_file.nil?
-
-             row_count = df(current_file).shape[0]
-             if row_count >= max_rows_per_file
-               new_file_path_for_segment(segment)
-             else
-               current_file
-             end
-           end
-
-           def new_file_path_for_segment(segment)
-             segment_dir = File.join(dir, segment.to_s)
-             file_number = Dir.glob(File.join(segment_dir, "*.csv")).count
-             File.join(segment_dir, "#{segment}_%04d.csv" % file_number)
-           end
-
-           def combine_dataframes(files)
-             dfs = files.map { |file| df(file) }
-             Polars.concat(dfs)
-           end
-         end
-       end
-     end
-   end
- end
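
Note: the removed FileSplit#save rotates per-segment CSV files once the current file reaches max_rows_per_file. The rotation loop in isolation, with a plain Array standing in for the Polars DataFrame and a chunk list standing in for the CSV files (a sketch of the removed logic, not the 0.2.0 FileSplit):

max_rows_per_file = 2
rows = [1, 2, 3, 4, 5]
remaining_rows = max_rows_per_file
files = []

while rows.length > 0
  if rows.length <= remaining_rows
    files << rows                      # remainder fits in the current file
    break
  else
    files << rows.take(remaining_rows) # fill the current file to its budget
    rows = rows.drop(remaining_rows)   # then rotate to a fresh file
    remaining_rows = max_rows_per_file
  end
end

files # => [[1, 2], [3, 4], [5]]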