easy_ml 0.1.4 → 0.2.0.pre.rc1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,181 +0,0 @@
|
|
1
|
-
require "carrierwave"
require_relative "uploaders/model_uploader"

module EasyML
  module Core
    # Behaviour shared by concrete model implementations: validations, the
    # +fit+ entry point, evaluation, and persisting the trained model file via
    # EasyML::Core::Uploaders::ModelUploader (mounted as +file+).
    #
    # NOTE(review): the including class is assumed to provide +train+,
    # +train_in_batches+, +evaluator+, +task+, +metrics+, +name+, +version+,
    # +ml_model+, +root_dir+ and +@hyperparameters+; none are defined here.
    module ModelCore
      attr_accessor :dataset

      # Include hook. It adds the GlueGun DSL, mounts the model uploader on
      # +file+, and registers the task/dataset/metric validations. It also
      # registers a before_validation callback that saves the model file once
      # the model has been fit.
      def self.included(base)
        base.send(:include, GlueGun::DSL)
        base.send(:extend, CarrierWave::Mount)
        base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)

        base.class_eval do
          validates :task, inclusion: { in: %w[regression classification] }
          validates :task, presence: true
          validate :dataset_is_a_dataset?
          validate :validate_any_metrics?
          validate :validate_metrics_for_task
          before_validation :save_model_file, if: -> { fit? }
        end
      end

      # Trains the model.
      #
      # With no +x_train+, the dataset is refreshed and +train_in_batches+ is
      # used. Otherwise the supplied data is forwarded to +train+.
      #
      # @return [true] the model is marked as fit
      def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
        if x_train.nil?
          dataset.refresh!
          train_in_batches
        else
          train(x_train, y_train, x_valid, y_valid)
        end
        @is_fit = true
      end

      # @raise [NotImplementedError] subclasses must implement prediction
      def predict(xs)
        raise NotImplementedError, "Subclasses must implement predict method"
      end

      # @raise [NotImplementedError] subclasses must implement loading
      def load
        raise NotImplementedError, "Subclasses must implement load method"
      end

      # Calls the superclass #save (when one exists), then writes and stores
      # the model file.
      def save
        super if defined?(super) && self.class.superclass.method_defined?(:save)
        save_model_file
      end

      # Delegates label decoding to the dataset.
      def decode_labels(ys, col: nil)
        dataset.decode_labels(ys, col: col)
      end

      # Evaluates predictions with the given evaluator (defaults to the
      # model's own +evaluator+).
      def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
        evaluator ||= self.evaluator
        EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
                                              evaluator: evaluator)
      end

      # Serializes the trained model to <model_dir>/<version>.json, stores it
      # through the uploader, then rotates old model files.
      #
      # @raise [RuntimeError] if the model has not been fit
      def save_model_file
        raise "No trained model! Need to train model before saving (call model.fit)" unless fit?

        path = File.join(model_dir, "#{version}.json")
        ensure_directory_exists(File.dirname(path))

        _save_model_file(path)

        File.open(path) do |f|
          self.file = f
        end
        file.store!

        cleanup
      end

      def get_params
        @hyperparameters.to_h
      end

      # Metric names valid for the current task; [] when task is unset/unknown.
      def allowed_metrics
        return [] unless task.present?

        case task.to_sym
        when :regression
          %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
        when :classification
          %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
        else
          []
        end
      end

      # Removes all allowlisted model files (keeps none). A missing uploader
      # path is skipped rather than passed to FileRotate as nil.
      def cleanup!
        [carrierwave_dir, model_dir].compact.each do |dir|
          EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
        end
      end

      # Keeps the 5 newest uploaded files and removes other allowlisted files.
      # When the uploader has no stored path, the files to keep cannot be
      # determined, so nothing is deleted. Previously this raised a TypeError
      # from File.join(nil, ...).
      def cleanup
        return if carrierwave_dir.nil?

        [carrierwave_dir, model_dir].each do |dir|
          EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
        end
      end

      def fit?
        @is_fit == true
      end

      private

      # Parent of the directory holding the stored upload, or nil when the
      # uploader has no path yet.
      def carrierwave_dir
        return unless file.path.present?

        File.dirname(file.path).split("/")[0..-2].join("/")
      end

      def extension_allowlist
        EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
      end

      # Subclasses write the serialized model to +path+.
      def _save_model_file(path = nil)
        raise NotImplementedError, "Subclasses must implement _save_model_file method"
      end

      def ensure_directory_exists(dir)
        FileUtils.mkdir_p(dir) unless File.directory?(dir)
      end

      def apply_defaults
        self.version ||= generate_version_string
        self.metrics ||= allowed_metrics
        self.ml_model ||= get_ml_model
      end

      def get_ml_model
        self.class.name.split("::").last.underscore
      end

      # "<model_name>_<YYYYMMDDhhmmss>" in UTC.
      def generate_version_string
        timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
        model_name = self.class.name.split("::").last.underscore
        "#{model_name}_#{timestamp}"
      end

      def model_dir
        File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
      end

      # The 5 most recent files under carrierwave_dir, ordered by the
      # timestamp embedded in their names; [] when there is no upload path.
      def files_to_keep
        dir = carrierwave_dir
        return [] if dir.nil?

        Dir.glob(File.join(dir, "**/*"))
           .select { |f| File.file?(f) }
           .sort_by { |filename| file_timestamp(filename) }
           .reverse
           .take(5)
      end

      # Parses the digits in a file's basename as a time. These are presumably
      # the YYYYMMDDhhmmss from generate_version_string. Names with no
      # parseable time sort as oldest instead of raising.
      def file_timestamp(filename)
        Time.parse(File.basename(filename).gsub(/\D/, ""))
      rescue ArgumentError
        Time.at(0)
      end

      def dataset_is_a_dataset?
        return if dataset.nil?
        return if dataset.is_a?(EasyML::Data::Dataset)

        errors.add(:dataset, "Must be a subclass of EasyML::Data::Dataset")
      end

      # Array() so a nil metrics list yields a validation error, not NoMethodError.
      def validate_any_metrics?
        return if Array(metrics).any?

        errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
      end

      def validate_metrics_for_task
        nonsensical_metrics = Array(metrics).reject { |metric| allowed_metrics.include?(metric) }
        return unless nonsensical_metrics.any?

        errors.add(:metrics,
                   "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
      end
    end
  end
end
|
@@ -1,34 +0,0 @@
|
|
1
|
-
module EasyML
  module Models
    module Hyperparameters
      # Base hyperparameter set. Attributes are declared through GlueGun::DSL;
      # model-specific subclasses add or override them.
      class Base
        include GlueGun::DSL

        attribute :learning_rate, :float, default: 0.01
        attribute :max_iterations, :integer, default: 100
        attribute :batch_size, :integer, default: 32
        attribute :regularization, :float, default: 0.0001

        # @return [Hash] the declared attributes and their current values
        def to_h
          attributes
        end

        # Returns a new instance of the same class with +other+'s values laid
        # over this one's.
        #
        # @param other [Base, Hash, nil] nil returns self unchanged
        def merge(other)
          return self if other.nil?

          # `Hyperparameters` is the enclosing namespace module, not an ancestor
          # of Base. The old `is_a?(Hyperparameters)` check therefore never
          # matched, and hyperparameter objects reached Hash#merge unconverted.
          other_hash = other.is_a?(Base) ? other.to_h : other
          merged_hash = to_h.merge(other_hash)
          self.class.new(**merged_hash)
        end

        # Reads a hyperparameter by name; nil if there is no such reader.
        def [](key)
          send(key) if respond_to?(key)
        end

        # Writes a hyperparameter by name; ignored if there is no such writer.
        def []=(key, value)
          send("#{key}=", value) if respond_to?("#{key}=")
        end
      end
    end
  end
end
|
@@ -1,19 +0,0 @@
|
|
1
|
-
module EasyML
  module Models
    module Hyperparameters
      # XGBoost hyperparameters. This class overrides Base's learning_rate
      # default and adds tree depth, estimator count, booster type and
      # objective.
      class XGBoost < Base
        include GlueGun::DSL

        attribute :learning_rate, :float, default: 0.1
        attribute :max_depth, :integer, default: 6
        attribute :n_estimators, :integer, default: 100
        attribute :booster, :string, default: "gbtree"
        attribute :objective, :string, default: "reg:squarederror"

        # Only the listed binary, multi-class and regression objectives are accepted.
        validates :objective, inclusion: {
          in: %w[
            binary:logistic binary:hinge
            multi:softmax multi:softprob
            reg:squarederror reg:logistic
          ]
        }
      end
    end
  end
end
|
@@ -1,220 +0,0 @@
|
|
1
|
-
require "wandb"
module EasyML
  module Core
    module Models
      # XGBoost-specific model behaviour: objective validation, training via
      # ::XGBoost, prediction, and model-file I/O.
      #
      # NOTE(review): this appears to be mixed in alongside
      # EasyML::Core::ModelCore, which supplies +fit+, +task+, +dataset+ and
      # +file+. Confirm this against the including class.
      module XGBoostCore
        # Objectives allowed per task. Classification is further split by the
        # number of distinct labels (<= 2 => :binary).
        OBJECTIVES = {
          classification: {
            binary: %w[binary:logistic binary:hinge],
            multi_class: %w[multi:softmax multi:softprob]
          },
          regression: %w[reg:squarederror reg:logistic]
        }.freeze

        def self.included(base)
          base.class_eval do
            attribute :evaluator

            dependency :callbacks, { array: true } do |dep|
              dep.option :wandb do |opt|
                opt.set_class Wandb::XGBoostCallback
                opt.bind_attribute :log_model, default: false
                opt.bind_attribute :log_feature_importance, default: true
                opt.bind_attribute :importance_type, default: "gain"
                opt.bind_attribute :define_metric, default: true
                opt.bind_attribute :project_name
              end
            end

            dependency :hyperparameters do |dep|
              dep.set_class EasyML::Models::Hyperparameters::XGBoost
              dep.bind_attribute :batch_size, default: 32
              # Was 1.1: outside XGBoost's documented eta range [0, 1] and
              # inconsistent with Hyperparameters::XGBoost's default of 0.1.
              dep.bind_attribute :learning_rate, default: 0.1
              dep.bind_attribute :max_depth, default: 6
              dep.bind_attribute :n_estimators, default: 100
              dep.bind_attribute :booster, default: "gbtree"
              dep.bind_attribute :objective, default: "reg:squarederror"
            end
          end
        end

        attr_accessor :model, :booster

        # Predicts for a dataframe +xs+. Classification returns class indices;
        # any other task returns the raw booster output.
        #
        # @raise [RuntimeError] if untrained or +xs+ is nil
        def predict(xs)
          raise "No trained model! Train a model before calling predict" unless @booster.present?
          raise "Cannot predict on nil — XGBoost" if xs.nil?

          y_pred = @booster.predict(preprocess(xs))

          case task.to_sym
          when :classification
            to_classification(y_pred)
          else
            y_pred
          end
        end

        # Class probabilities. Binary outputs are expanded to [1 - p, p] pairs.
        #
        # Dataframes go through the same +preprocess+ as #predict. Other inputs
        # are wrapped in ::XGBoost::DMatrix; the old bare `DMatrix` constant
        # was unresolvable here.
        def predict_proba(data)
          raise "No trained model! Train a model before calling predict_proba" unless @booster.present?

          dmat = data.respond_to?(:columns) ? preprocess(data) : d_matrix_class.new(data)
          y_pred = @booster.predict(dmat)

          if y_pred.first.is_a?(Array)
            # multiple classes
            y_pred
          else
            y_pred.map { |v| [1 - v, v] }
          end
        end

        # Loads a saved model from +path+ (defaults to the mounted uploader).
        #
        # @raise [RuntimeError] if no file exists at the resolved path
        def load(path = nil)
          path ||= file
          path = path&.file&.file if path.is_a?(CarrierWave::Uploader::Base)

          # Guard nil so a missing upload reports clearly instead of TypeError.
          raise "No existing model at #{path}" unless path && File.exist?(path)

          initialize_model do
            booster_class.new(params: hyperparameters.to_h, model_file: path)
          end
        end

        # Writes the trained booster to +path+.
        def _save_model_file(path)
          @booster.save_model(path)
        end

        # Map of feature name => importance from the trained model.
        def feature_importances
          @model.booster.feature_names.zip(@model.feature_importances).to_h
        end

        def base_model
          ::XGBoost
        end

        def customize_callbacks
          yield callbacks
        end

        private

        def booster_class
          ::XGBoost::Booster
        end

        def d_matrix_class
          ::XGBoost::DMatrix
        end

        def model_class
          ::XGBoost::Model
        end

        # Trains on explicitly supplied frames. ModelCore#fit calls this with
        # (x_train, y_train, x_valid, y_valid). The old zero-arity version
        # raised ArgumentError there, and its body read undefined locals.
        # Validation data is optional.
        #
        # @raise [ArgumentError] if training features or labels are missing
        def train(x_train = nil, y_train = nil, x_valid = nil, y_valid = nil)
          raise ArgumentError, "x_train and y_train are required to train" if x_train.nil? || y_train.nil?

          validate_objective(y_train)

          @model = nil
          @booster = nil
          fit_batch(x_train, y_train, x_valid, y_valid)
        end

        # Trains from the dataset's train/valid splits.
        def train_in_batches
          validate_objective

          @model = nil
          @booster = nil
          x_valid, y_valid = dataset.valid(split_ys: true)
          x_train, y_train = dataset.train(split_ys: true)
          fit_batch(x_train, y_train, x_valid, y_valid)
        end

        # Converts a dataframe to rows of XGBoost-friendly values. Times and
        # dates become Unix timestamps, booleans become 1.0/0.0, and integers
        # and strings pass through. Everything else is coerced with to_f.
        def _preprocess(df)
          df.to_a.map do |row|
            row.values.map do |value|
              case value
              when Time
                value.to_i
              when Date
                value.to_time.to_i
              when String
                value
              when TrueClass, FalseClass
                value ? 1.0 : 0.0
              when Integer
                value
              else
                value.to_f
              end
            end
          end
        end

        # Builds a DMatrix (with labels when +ys+ is given) and names its
        # features after the frame's columns.
        def preprocess(xs, ys = nil)
          column_names = xs.columns
          xs = _preprocess(xs)
          ys = ys.nil? ? nil : _preprocess(ys).flatten
          kwargs = { label: ys }.compact
          ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
            dmat.feature_names = column_names
          end
        end

        # Yields to obtain a booster and attaches it to a fresh ::XGBoost::Model.
        def initialize_model
          @model = model_class.new(n_estimators: num_boost_rounds)
          @booster = yield
          @model.instance_variable_set(:@booster, @booster)
        end

        # n_estimators from the hyperparameter hash. The code previously read
        # it with a symbol key in one place and a string key in another, so one
        # of those lookups always returned nil. Both key types are accepted.
        def num_boost_rounds
          params = @hyperparameters.to_h
          params[:n_estimators] || params["n_estimators"]
        end

        # Ensures the configured objective fits the task. For classification
        # the label cardinality of +ys+ is used, falling back to the dataset.
        #
        # @raise [ArgumentError] if task is missing or the objective is not allowed
        def validate_objective(ys = nil)
          objective = hyperparameters.objective
          unless task.present?
            raise ArgumentError,
                  "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
          end

          case task.to_sym
          when :classification
            _, ys = dataset.data(split_ys: true) if ys.nil?
            classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
            allowed_objectives = OBJECTIVES[:classification][classification_type]
          else
            # Unknown tasks get an empty list, so they hit the ArgumentError below instead of NoMethodError.
            allowed_objectives = OBJECTIVES[task.to_sym] || []
          end
          return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)

          raise ArgumentError,
                "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
        end

        # Trains (or updates) the booster on one batch. The "eval" set is only
        # added when both validation features and labels are present.
        def fit_batch(x_train, y_train, x_valid, y_valid)
          d_train = preprocess(x_train, y_train)

          evals = [[d_train, "train"]]
          evals << [preprocess(x_valid, y_valid), "eval"] unless x_valid.nil? || y_valid.nil?

          if @booster.nil?
            initialize_model do
              base_model.train(@hyperparameters.to_h, d_train,
                               num_boost_round: num_boost_rounds, evals: evals, callbacks: callbacks)
            end
          else
            # NOTE(review): unreachable from train/train_in_batches (both reset
            # @booster to nil first). Confirm ::XGBoost::Model responds to
            # #update before relying on this branch.
            @model.update(d_train)
          end
        end

        # Multi-class rows map to the argmax index; binary scores are
        # thresholded at 0.5.
        def to_classification(y_pred)
          if y_pred.first.is_a?(Array)
            y_pred.map do |v|
              v.map.with_index.max_by { |v2, _| v2 }.last
            end
          else
            y_pred.map { |v| v > 0.5 ? 1 : 0 }
          end
        end
      end
    end
  end
end
|
data/lib/easy_ml/core/models.rb
DELETED
@@ -1,24 +0,0 @@
|
|
1
|
-
require "carrierwave"

module EasyML
  module Core
    module Uploaders
      # CarrierWave uploader for serialized model files. Production uses fog
      # storage; all other environments use local file storage.
      class ModelUploader < CarrierWave::Uploader::Base
        storage(Rails.env.production? ? :fog : :file)

        # Storage directory for the owning model's files, keyed by its name.
        def store_dir
          "easy_ml_models/" + model.name.to_s
        end

        # File extensions accepted for upload.
        def extension_allowlist
          ["bin", "model", "json"]
        end
      end
    end
  end
end
|
@@ -1,31 +0,0 @@
|
|
1
|
-
{
|
2
|
-
"annual_revenue": {
|
3
|
-
"median": {
|
4
|
-
"value": 3000.0,
|
5
|
-
"original_dtype": {
|
6
|
-
"__type__": "polars_dtype",
|
7
|
-
"value": "Polars::Int64"
|
8
|
-
}
|
9
|
-
}
|
10
|
-
},
|
11
|
-
"loan_purpose": {
|
12
|
-
"categorical": {
|
13
|
-
"value": {
|
14
|
-
"payroll": 4,
|
15
|
-
"expansion": 1
|
16
|
-
},
|
17
|
-
"label_encoder": {
|
18
|
-
"expansion": 0,
|
19
|
-
"payroll": 1
|
20
|
-
},
|
21
|
-
"label_decoder": {
|
22
|
-
"0": "expansion",
|
23
|
-
"1": "payroll"
|
24
|
-
},
|
25
|
-
"original_dtype": {
|
26
|
-
"__type__": "polars_dtype",
|
27
|
-
"value": "Polars::String"
|
28
|
-
}
|
29
|
-
}
|
30
|
-
}
|
31
|
-
}
|
@@ -1 +0,0 @@
|
|
1
|
-
{"previous_sample":1.0}
|
@@ -1 +0,0 @@
|
|
1
|
-
{"previous_sample":1.0}
|
@@ -1,140 +0,0 @@
|
|
1
|
-
require_relative "split"

module EasyML
  module Data
    class Dataset
      module Splits
        # A Split persisted as CSV files under +dir+. Each segment gets its own
        # sub-directory, and a new file is started once the current one holds
        # +max_rows_per_file+ rows.
        class FileSplit < Split
          include GlueGun::DSL
          include EasyML::Data::Utils

          attribute :dir, :string
          attribute :polars_args, :hash, default: {}
          attribute :max_rows_per_file, :integer, default: 1_000_000
          attribute :batch_size, :integer, default: 10_000
          attribute :sample, :float, default: 1.0
          attribute :verbose, :boolean, default: false

          def initialize(options)
            super
            FileUtils.mkdir_p(dir)
          end

          # Appends +frame+ to +segment+'s CSV files, filling the current file
          # up to max_rows_per_file before rolling over to a new one.
          def save(segment, frame)
            segment_dir = File.join(dir, segment.to_s)
            FileUtils.mkdir_p(segment_dir)

            current_file = current_file_for_segment(segment)
            current_row_count = current_file && File.exist?(current_file) ? df(current_file).shape[0] : 0
            remaining_rows = max_rows_per_file - current_row_count

            while frame.shape[0] > 0
              if frame.shape[0] <= remaining_rows
                append_to_csv(frame, current_file)
                break
              else
                df_to_append = frame.slice(0, remaining_rows)
                frame = frame.slice(remaining_rows, frame.shape[0] - remaining_rows)
                append_to_csv(df_to_append, current_file)
                current_file = new_file_path_for_segment(segment)
                remaining_rows = max_rows_per_file
              end
            end
          end

          # Reads a segment.
          #
          # With a block, the files are processed one at a time, each sampled
          # and with +drop_cols+ removed, and the folded result is returned.
          # Without a block, all files are concatenated. An empty segment
          # returns nil ([nil, nil] when +split_ys+).
          def read(segment, split_ys: false, target: nil, drop_cols: [], &block)
            files = files_for_segment(segment)

            if block_given?
              result = nil
              # Only pay for the extra full read of every file when a progress
              # bar is actually shown.
              if verbose
                total_rows = files.sum { |file| df(file).shape[0] }
                progress_bar = create_progress_bar(segment, total_rows)
              end

              files.each do |file|
                frame = df(file)
                file_rows = frame.shape[0]
                frame = sample_data(frame) if sample < 1.0
                # Per-file intersection. Reassigning drop_cols here (as before)
                # narrowed it for every later file.
                cols_to_drop = drop_cols & frame.columns
                frame = frame.drop(cols_to_drop) unless cols_to_drop.empty?

                if split_ys
                  xs, ys = split_features_targets(frame, true, target)
                  result = process_block_with_split_ys(block, result, xs, ys)
                else
                  result = process_block_without_split_ys(block, result, frame)
                end

                # Advance by pre-sample rows so progress matches total_rows.
                progress_bar.progress += file_rows if verbose
              end
              progress_bar.finish if verbose
              result
            elsif files.empty?
              return nil, nil if split_ys

              nil
            else
              combined_df = combine_dataframes(files)
              combined_df = sample_data(combined_df) if sample < 1.0
              cols_to_drop = drop_cols & combined_df.columns
              combined_df = combined_df.drop(cols_to_drop) unless cols_to_drop.empty?
              split_features_targets(combined_df, split_ys, target)
            end
          end

          # Deletes all persisted files and recreates the empty root directory.
          def cleanup
            FileUtils.rm_rf(dir)
            FileUtils.mkdir_p(dir)
          end

          # Latest modification time across all CSVs, or nil if none exist.
          def split_at
            return nil if output_files.empty?

            output_files.map { |file| File.mtime(file) }.max
          end

          private

          def read_csv_batched(path)
            Polars.read_csv_batched(path, batch_size: batch_size, **polars_args)
          end

          def df(path)
            Polars.read_csv(path, **polars_args)
          end

          def output_files
            Dir.glob("#{dir}/**/*.csv")
          end

          def files_for_segment(segment)
            segment_dir = File.join(dir, segment.to_s)
            Dir.glob(File.join(segment_dir, "**/*.csv")).sort
          end

          # The newest segment file if it still has room, else a new file path.
          def current_file_for_segment(segment)
            current_file = files_for_segment(segment).last
            return new_file_path_for_segment(segment) if current_file.nil?

            row_count = df(current_file).shape[0]
            if row_count >= max_rows_per_file
              new_file_path_for_segment(segment)
            else
              current_file
            end
          end

          # Next sequential path, e.g. <dir>/train/train_0003.csv.
          def new_file_path_for_segment(segment)
            segment_dir = File.join(dir, segment.to_s)
            file_number = Dir.glob(File.join(segment_dir, "*.csv")).count
            File.join(segment_dir, "#{segment}_%04d.csv" % file_number)
          end

          def combine_dataframes(files)
            dfs = files.map { |file| df(file) }
            Polars.concat(dfs)
          end
        end
      end
    end
  end
end
|