easy_ml 0.1.4 → 0.2.0.pre.rc1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,181 +0,0 @@
|
|
1
|
-
require "carrierwave"
|
2
|
-
require_relative "uploaders/model_uploader"
|
3
|
-
|
4
|
-
module EasyML
|
5
|
-
module Core
|
6
|
-
module ModelCore
|
7
|
-
attr_accessor :dataset
|
8
|
-
|
9
|
-
def self.included(base)
|
10
|
-
base.send(:include, GlueGun::DSL)
|
11
|
-
base.send(:extend, CarrierWave::Mount)
|
12
|
-
base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)
|
13
|
-
|
14
|
-
base.class_eval do
|
15
|
-
validates :task, inclusion: { in: %w[regression classification] }
|
16
|
-
validates :task, presence: true
|
17
|
-
validate :dataset_is_a_dataset?
|
18
|
-
validate :validate_any_metrics?
|
19
|
-
validate :validate_metrics_for_task
|
20
|
-
before_validation :save_model_file, if: -> { fit? }
|
21
|
-
end
|
22
|
-
end
|
23
|
-
|
24
|
-
def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
|
25
|
-
if x_train.nil?
|
26
|
-
dataset.refresh!
|
27
|
-
train_in_batches
|
28
|
-
else
|
29
|
-
train(x_train, y_train, x_valid, y_valid)
|
30
|
-
end
|
31
|
-
@is_fit = true
|
32
|
-
end
|
33
|
-
|
34
|
-
def predict(xs)
|
35
|
-
raise NotImplementedError, "Subclasses must implement predict method"
|
36
|
-
end
|
37
|
-
|
38
|
-
def load
|
39
|
-
raise NotImplementedError, "Subclasses must implement load method"
|
40
|
-
end
|
41
|
-
|
42
|
-
def _save_model_file
|
43
|
-
raise NotImplementedError, "Subclasses must implement _save_model_file method"
|
44
|
-
end
|
45
|
-
|
46
|
-
def save
|
47
|
-
super if defined?(super) && self.class.superclass.method_defined?(:save)
|
48
|
-
save_model_file
|
49
|
-
end
|
50
|
-
|
51
|
-
def decode_labels(ys, col: nil)
|
52
|
-
dataset.decode_labels(ys, col: col)
|
53
|
-
end
|
54
|
-
|
55
|
-
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
|
56
|
-
evaluator ||= self.evaluator
|
57
|
-
EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
|
58
|
-
evaluator: evaluator)
|
59
|
-
end
|
60
|
-
|
61
|
-
def save_model_file
|
62
|
-
raise "No trained model! Need to train model before saving (call model.fit)" unless fit?
|
63
|
-
|
64
|
-
path = File.join(model_dir, "#{version}.json")
|
65
|
-
ensure_directory_exists(File.dirname(path))
|
66
|
-
|
67
|
-
_save_model_file(path)
|
68
|
-
|
69
|
-
File.open(path) do |f|
|
70
|
-
self.file = f
|
71
|
-
end
|
72
|
-
file.store!
|
73
|
-
|
74
|
-
cleanup
|
75
|
-
end
|
76
|
-
|
77
|
-
def get_params
|
78
|
-
@hyperparameters.to_h
|
79
|
-
end
|
80
|
-
|
81
|
-
def allowed_metrics
|
82
|
-
return [] unless task.present?
|
83
|
-
|
84
|
-
case task.to_sym
|
85
|
-
when :regression
|
86
|
-
%w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
|
87
|
-
when :classification
|
88
|
-
%w[accuracy_score precision_score recall_score f1_score auc roc_auc]
|
89
|
-
else
|
90
|
-
[]
|
91
|
-
end
|
92
|
-
end
|
93
|
-
|
94
|
-
def cleanup!
|
95
|
-
[carrierwave_dir, model_dir].each do |dir|
|
96
|
-
EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
|
97
|
-
end
|
98
|
-
end
|
99
|
-
|
100
|
-
def cleanup
|
101
|
-
[carrierwave_dir, model_dir].each do |dir|
|
102
|
-
EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
|
103
|
-
end
|
104
|
-
end
|
105
|
-
|
106
|
-
def fit?
|
107
|
-
@is_fit == true
|
108
|
-
end
|
109
|
-
|
110
|
-
private
|
111
|
-
|
112
|
-
def carrierwave_dir
|
113
|
-
return unless file.path.present?
|
114
|
-
|
115
|
-
File.dirname(file.path).split("/")[0..-2].join("/")
|
116
|
-
end
|
117
|
-
|
118
|
-
def extension_allowlist
|
119
|
-
EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
|
120
|
-
end
|
121
|
-
|
122
|
-
def _save_model_file(path = nil)
|
123
|
-
raise NotImplementedError, "Subclasses must implement _save_model_file method"
|
124
|
-
end
|
125
|
-
|
126
|
-
def ensure_directory_exists(dir)
|
127
|
-
FileUtils.mkdir_p(dir) unless File.directory?(dir)
|
128
|
-
end
|
129
|
-
|
130
|
-
def apply_defaults
|
131
|
-
self.version ||= generate_version_string
|
132
|
-
self.metrics ||= allowed_metrics
|
133
|
-
self.ml_model ||= get_ml_model
|
134
|
-
end
|
135
|
-
|
136
|
-
def get_ml_model
|
137
|
-
self.class.name.split("::").last.underscore
|
138
|
-
end
|
139
|
-
|
140
|
-
def generate_version_string
|
141
|
-
timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
|
142
|
-
model_name = self.class.name.split("::").last.underscore
|
143
|
-
"#{model_name}_#{timestamp}"
|
144
|
-
end
|
145
|
-
|
146
|
-
def model_dir
|
147
|
-
File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
|
148
|
-
end
|
149
|
-
|
150
|
-
def files_to_keep
|
151
|
-
Dir.glob(File.join(carrierwave_dir, "**/*")).select { |f| File.file?(f) }.sort_by do |filename|
|
152
|
-
Time.parse(filename.split("/").last.gsub(/\D/, ""))
|
153
|
-
end.reverse.take(5)
|
154
|
-
end
|
155
|
-
|
156
|
-
def dataset_is_a_dataset?
|
157
|
-
return if dataset.nil?
|
158
|
-
return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
|
159
|
-
|
160
|
-
errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
|
161
|
-
end
|
162
|
-
|
163
|
-
def validate_any_metrics?
|
164
|
-
return if metrics.any?
|
165
|
-
|
166
|
-
errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
|
167
|
-
end
|
168
|
-
|
169
|
-
def validate_metrics_for_task
|
170
|
-
nonsensical_metrics = metrics.select do |metric|
|
171
|
-
allowed_metrics.exclude?(metric)
|
172
|
-
end
|
173
|
-
|
174
|
-
return unless nonsensical_metrics.any?
|
175
|
-
|
176
|
-
errors.add(:metrics,
|
177
|
-
"cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
|
178
|
-
end
|
179
|
-
end
|
180
|
-
end
|
181
|
-
end
|
@@ -1,34 +0,0 @@
|
|
1
|
-
module EasyML
|
2
|
-
module Models
|
3
|
-
module Hyperparameters
|
4
|
-
class Base
|
5
|
-
include GlueGun::DSL
|
6
|
-
|
7
|
-
attribute :learning_rate, :float, default: 0.01
|
8
|
-
attribute :max_iterations, :integer, default: 100
|
9
|
-
attribute :batch_size, :integer, default: 32
|
10
|
-
attribute :regularization, :float, default: 0.0001
|
11
|
-
|
12
|
-
def to_h
|
13
|
-
attributes
|
14
|
-
end
|
15
|
-
|
16
|
-
def merge(other)
|
17
|
-
return self if other.nil?
|
18
|
-
|
19
|
-
other_hash = other.is_a?(Hyperparameters) ? other.to_h : other
|
20
|
-
merged_hash = to_h.merge(other_hash)
|
21
|
-
self.class.new(**merged_hash)
|
22
|
-
end
|
23
|
-
|
24
|
-
def [](key)
|
25
|
-
send(key) if respond_to?(key)
|
26
|
-
end
|
27
|
-
|
28
|
-
def []=(key, value)
|
29
|
-
send("#{key}=", value) if respond_to?("#{key}=")
|
30
|
-
end
|
31
|
-
end
|
32
|
-
end
|
33
|
-
end
|
34
|
-
end
|
@@ -1,19 +0,0 @@
|
|
1
|
-
module EasyML
|
2
|
-
module Models
|
3
|
-
module Hyperparameters
|
4
|
-
class XGBoost < Base
|
5
|
-
include GlueGun::DSL
|
6
|
-
|
7
|
-
attribute :learning_rate, :float, default: 0.1
|
8
|
-
attribute :max_depth, :integer, default: 6
|
9
|
-
attribute :n_estimators, :integer, default: 100
|
10
|
-
attribute :booster, :string, default: "gbtree"
|
11
|
-
attribute :objective, :string, default: "reg:squarederror"
|
12
|
-
|
13
|
-
validates :objective,
|
14
|
-
inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
|
15
|
-
reg:logistic] }
|
16
|
-
end
|
17
|
-
end
|
18
|
-
end
|
19
|
-
end
|
@@ -1,220 +0,0 @@
|
|
1
|
-
require "wandb"
|
2
|
-
module EasyML
|
3
|
-
module Core
|
4
|
-
module Models
|
5
|
-
module XGBoostCore
|
6
|
-
OBJECTIVES = {
|
7
|
-
classification: {
|
8
|
-
binary: %w[binary:logistic binary:hinge],
|
9
|
-
multi_class: %w[multi:softmax multi:softprob]
|
10
|
-
},
|
11
|
-
regression: %w[reg:squarederror reg:logistic]
|
12
|
-
}
|
13
|
-
|
14
|
-
def self.included(base)
|
15
|
-
base.class_eval do
|
16
|
-
attribute :evaluator
|
17
|
-
|
18
|
-
dependency :callbacks, { array: true } do |dep|
|
19
|
-
dep.option :wandb do |opt|
|
20
|
-
opt.set_class Wandb::XGBoostCallback
|
21
|
-
opt.bind_attribute :log_model, default: false
|
22
|
-
opt.bind_attribute :log_feature_importance, default: true
|
23
|
-
opt.bind_attribute :importance_type, default: "gain"
|
24
|
-
opt.bind_attribute :define_metric, default: true
|
25
|
-
opt.bind_attribute :project_name
|
26
|
-
end
|
27
|
-
end
|
28
|
-
|
29
|
-
dependency :hyperparameters do |dep|
|
30
|
-
dep.set_class EasyML::Models::Hyperparameters::XGBoost
|
31
|
-
dep.bind_attribute :batch_size, default: 32
|
32
|
-
dep.bind_attribute :learning_rate, default: 1.1
|
33
|
-
dep.bind_attribute :max_depth, default: 6
|
34
|
-
dep.bind_attribute :n_estimators, default: 100
|
35
|
-
dep.bind_attribute :booster, default: "gbtree"
|
36
|
-
dep.bind_attribute :objective, default: "reg:squarederror"
|
37
|
-
end
|
38
|
-
end
|
39
|
-
end
|
40
|
-
|
41
|
-
attr_accessor :model, :booster
|
42
|
-
|
43
|
-
def predict(xs)
|
44
|
-
raise "No trained model! Train a model before calling predict" unless @booster.present?
|
45
|
-
raise "Cannot predict on nil — XGBoost" if xs.nil?
|
46
|
-
|
47
|
-
y_pred = @booster.predict(preprocess(xs))
|
48
|
-
|
49
|
-
case task.to_sym
|
50
|
-
when :classification
|
51
|
-
to_classification(y_pred)
|
52
|
-
else
|
53
|
-
y_pred
|
54
|
-
end
|
55
|
-
end
|
56
|
-
|
57
|
-
def predict_proba(data)
|
58
|
-
dmat = DMatrix.new(data)
|
59
|
-
y_pred = @booster.predict(dmat)
|
60
|
-
|
61
|
-
if y_pred.first.is_a?(Array)
|
62
|
-
# multiple classes
|
63
|
-
y_pred
|
64
|
-
else
|
65
|
-
y_pred.map { |v| [1 - v, v] }
|
66
|
-
end
|
67
|
-
end
|
68
|
-
|
69
|
-
def load(path = nil)
|
70
|
-
path ||= file
|
71
|
-
path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
|
72
|
-
|
73
|
-
raise "No existing model at #{path}" unless File.exist?(path)
|
74
|
-
|
75
|
-
initialize_model do
|
76
|
-
booster_class.new(params: hyperparameters.to_h, model_file: path)
|
77
|
-
end
|
78
|
-
end
|
79
|
-
|
80
|
-
def _save_model_file(path)
|
81
|
-
puts "XGBoost received path #{path}"
|
82
|
-
@booster.save_model(path)
|
83
|
-
end
|
84
|
-
|
85
|
-
def feature_importances
|
86
|
-
@model.booster.feature_names.zip(@model.feature_importances).to_h
|
87
|
-
end
|
88
|
-
|
89
|
-
def base_model
|
90
|
-
::XGBoost
|
91
|
-
end
|
92
|
-
|
93
|
-
def customize_callbacks
|
94
|
-
yield callbacks
|
95
|
-
end
|
96
|
-
|
97
|
-
private
|
98
|
-
|
99
|
-
def booster_class
|
100
|
-
::XGBoost::Booster
|
101
|
-
end
|
102
|
-
|
103
|
-
def d_matrix_class
|
104
|
-
::XGBoost::DMatrix
|
105
|
-
end
|
106
|
-
|
107
|
-
def model_class
|
108
|
-
::XGBoost::Model
|
109
|
-
end
|
110
|
-
|
111
|
-
def train
|
112
|
-
validate_objective
|
113
|
-
|
114
|
-
xs = xs.to_a.map(&:values)
|
115
|
-
ys = ys.to_a.map(&:values)
|
116
|
-
dtrain = d_matrix_class.new(xs, label: ys)
|
117
|
-
@model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
|
118
|
-
end
|
119
|
-
|
120
|
-
def train_in_batches
|
121
|
-
validate_objective
|
122
|
-
|
123
|
-
# Initialize the model with the first batch
|
124
|
-
@model = nil
|
125
|
-
@booster = nil
|
126
|
-
x_valid, y_valid = dataset.valid(split_ys: true)
|
127
|
-
x_train, y_train = dataset.train(split_ys: true)
|
128
|
-
fit_batch(x_train, y_train, x_valid, y_valid)
|
129
|
-
end
|
130
|
-
|
131
|
-
def _preprocess(df)
|
132
|
-
df.to_a.map do |row|
|
133
|
-
row.values.map do |value|
|
134
|
-
case value
|
135
|
-
when Time
|
136
|
-
value.to_i # Convert Time to Unix timestamp
|
137
|
-
when Date
|
138
|
-
value.to_time.to_i # Convert Date to Unix timestamp
|
139
|
-
when String
|
140
|
-
value
|
141
|
-
when TrueClass, FalseClass
|
142
|
-
value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
|
143
|
-
when Integer
|
144
|
-
value
|
145
|
-
else
|
146
|
-
value.to_f # Ensure everything else is converted to a float
|
147
|
-
end
|
148
|
-
end
|
149
|
-
end
|
150
|
-
end
|
151
|
-
|
152
|
-
def preprocess(xs, ys = nil)
|
153
|
-
column_names = xs.columns
|
154
|
-
xs = _preprocess(xs)
|
155
|
-
ys = ys.nil? ? nil : _preprocess(ys).flatten
|
156
|
-
kwargs = { label: ys }.compact
|
157
|
-
::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
|
158
|
-
dmat.feature_names = column_names
|
159
|
-
end
|
160
|
-
end
|
161
|
-
|
162
|
-
def initialize_model
|
163
|
-
@model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
|
164
|
-
@booster = yield
|
165
|
-
@model.instance_variable_set(:@booster, @booster)
|
166
|
-
end
|
167
|
-
|
168
|
-
def validate_objective
|
169
|
-
objective = hyperparameters.objective
|
170
|
-
unless task.present?
|
171
|
-
raise ArgumentError,
|
172
|
-
"cannot train model without task. Please specify either regression or classification (model.task = :regression)"
|
173
|
-
end
|
174
|
-
|
175
|
-
case task.to_sym
|
176
|
-
when :classification
|
177
|
-
_, ys = dataset.data(split_ys: true)
|
178
|
-
classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
|
179
|
-
allowed_objectives = OBJECTIVES[:classification][classification_type]
|
180
|
-
else
|
181
|
-
allowed_objectives = OBJECTIVES[task.to_sym]
|
182
|
-
end
|
183
|
-
return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
|
184
|
-
|
185
|
-
raise ArgumentError,
|
186
|
-
"cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
|
187
|
-
end
|
188
|
-
|
189
|
-
def fit_batch(x_train, y_train, x_valid, y_valid)
|
190
|
-
d_train = preprocess(x_train, y_train)
|
191
|
-
d_valid = preprocess(x_valid, y_valid)
|
192
|
-
|
193
|
-
evals = [[d_train, "train"], [d_valid, "eval"]]
|
194
|
-
|
195
|
-
# # If this is the first batch, create the booster
|
196
|
-
if @booster.nil?
|
197
|
-
initialize_model do
|
198
|
-
base_model.train(@hyperparameters.to_h, d_train,
|
199
|
-
num_boost_round: @hyperparameters.to_h.dig("n_estimators"), evals: evals, callbacks: callbacks)
|
200
|
-
end
|
201
|
-
else
|
202
|
-
# Update the existing booster with the new batch
|
203
|
-
@model.update(d_train)
|
204
|
-
end
|
205
|
-
end
|
206
|
-
|
207
|
-
def to_classification(y_pred)
|
208
|
-
if y_pred.first.is_a?(Array)
|
209
|
-
# multiple classes
|
210
|
-
y_pred.map do |v|
|
211
|
-
v.map.with_index.max_by { |v2, _| v2 }.last
|
212
|
-
end
|
213
|
-
else
|
214
|
-
y_pred.map { |v| v > 0.5 ? 1 : 0 }
|
215
|
-
end
|
216
|
-
end
|
217
|
-
end
|
218
|
-
end
|
219
|
-
end
|
220
|
-
end
|
data/lib/easy_ml/core/models.rb
DELETED
@@ -1,24 +0,0 @@
|
|
1
|
-
require "carrierwave"
|
2
|
-
|
3
|
-
module EasyML
|
4
|
-
module Core
|
5
|
-
module Uploaders
|
6
|
-
class ModelUploader < CarrierWave::Uploader::Base
|
7
|
-
# Choose storage type
|
8
|
-
if Rails.env.production?
|
9
|
-
storage :fog
|
10
|
-
else
|
11
|
-
storage :file
|
12
|
-
end
|
13
|
-
|
14
|
-
def store_dir
|
15
|
-
"easy_ml_models/#{model.name}"
|
16
|
-
end
|
17
|
-
|
18
|
-
def extension_allowlist
|
19
|
-
%w[bin model json]
|
20
|
-
end
|
21
|
-
end
|
22
|
-
end
|
23
|
-
end
|
24
|
-
end
|
@@ -1,31 +0,0 @@
|
|
1
|
-
{
|
2
|
-
"annual_revenue": {
|
3
|
-
"median": {
|
4
|
-
"value": 3000.0,
|
5
|
-
"original_dtype": {
|
6
|
-
"__type__": "polars_dtype",
|
7
|
-
"value": "Polars::Int64"
|
8
|
-
}
|
9
|
-
}
|
10
|
-
},
|
11
|
-
"loan_purpose": {
|
12
|
-
"categorical": {
|
13
|
-
"value": {
|
14
|
-
"payroll": 4,
|
15
|
-
"expansion": 1
|
16
|
-
},
|
17
|
-
"label_encoder": {
|
18
|
-
"expansion": 0,
|
19
|
-
"payroll": 1
|
20
|
-
},
|
21
|
-
"label_decoder": {
|
22
|
-
"0": "expansion",
|
23
|
-
"1": "payroll"
|
24
|
-
},
|
25
|
-
"original_dtype": {
|
26
|
-
"__type__": "polars_dtype",
|
27
|
-
"value": "Polars::String"
|
28
|
-
}
|
29
|
-
}
|
30
|
-
}
|
31
|
-
}
|
@@ -1 +0,0 @@
|
|
1
|
-
{"previous_sample":1.0}
|
@@ -1 +0,0 @@
|
|
1
|
-
{"previous_sample":1.0}
|
@@ -1,140 +0,0 @@
|
|
1
|
-
require_relative "split"
|
2
|
-
|
3
|
-
module EasyML
|
4
|
-
module Data
|
5
|
-
class Dataset
|
6
|
-
module Splits
|
7
|
-
class FileSplit < Split
|
8
|
-
include GlueGun::DSL
|
9
|
-
include EasyML::Data::Utils
|
10
|
-
|
11
|
-
attribute :dir, :string
|
12
|
-
attribute :polars_args, :hash, default: {}
|
13
|
-
attribute :max_rows_per_file, :integer, default: 1_000_000
|
14
|
-
attribute :batch_size, :integer, default: 10_000
|
15
|
-
attribute :sample, :float, default: 1.0
|
16
|
-
attribute :verbose, :boolean, default: false
|
17
|
-
|
18
|
-
def initialize(options)
|
19
|
-
super
|
20
|
-
FileUtils.mkdir_p(dir)
|
21
|
-
end
|
22
|
-
|
23
|
-
def save(segment, df)
|
24
|
-
segment_dir = File.join(dir, segment.to_s)
|
25
|
-
FileUtils.mkdir_p(segment_dir)
|
26
|
-
|
27
|
-
current_file = current_file_for_segment(segment)
|
28
|
-
current_row_count = current_file && File.exist?(current_file) ? df(current_file).shape[0] : 0
|
29
|
-
remaining_rows = max_rows_per_file - current_row_count
|
30
|
-
|
31
|
-
while df.shape[0] > 0
|
32
|
-
if df.shape[0] <= remaining_rows
|
33
|
-
append_to_csv(df, current_file)
|
34
|
-
break
|
35
|
-
else
|
36
|
-
df_to_append = df.slice(0, remaining_rows)
|
37
|
-
df = df.slice(remaining_rows, df.shape[0] - remaining_rows)
|
38
|
-
append_to_csv(df_to_append, current_file)
|
39
|
-
current_file = new_file_path_for_segment(segment)
|
40
|
-
remaining_rows = max_rows_per_file
|
41
|
-
end
|
42
|
-
end
|
43
|
-
end
|
44
|
-
|
45
|
-
def read(segment, split_ys: false, target: nil, drop_cols: [], &block)
|
46
|
-
files = files_for_segment(segment)
|
47
|
-
|
48
|
-
if block_given?
|
49
|
-
result = nil
|
50
|
-
total_rows = files.sum { |file| df(file).shape[0] }
|
51
|
-
progress_bar = create_progress_bar(segment, total_rows) if verbose
|
52
|
-
|
53
|
-
files.each do |file|
|
54
|
-
df = self.df(file)
|
55
|
-
df = sample_data(df) if sample < 1.0
|
56
|
-
drop_cols &= df.columns
|
57
|
-
df = df.drop(drop_cols) unless drop_cols.empty?
|
58
|
-
|
59
|
-
if split_ys
|
60
|
-
xs, ys = split_features_targets(df, true, target)
|
61
|
-
result = process_block_with_split_ys(block, result, xs, ys)
|
62
|
-
else
|
63
|
-
result = process_block_without_split_ys(block, result, df)
|
64
|
-
end
|
65
|
-
|
66
|
-
progress_bar.progress += df.shape[0] if verbose
|
67
|
-
end
|
68
|
-
progress_bar.finish if verbose
|
69
|
-
result
|
70
|
-
elsif files.empty?
|
71
|
-
return nil, nil if split_ys
|
72
|
-
|
73
|
-
nil
|
74
|
-
|
75
|
-
else
|
76
|
-
combined_df = combine_dataframes(files)
|
77
|
-
combined_df = sample_data(combined_df) if sample < 1.0
|
78
|
-
drop_cols &= combined_df.columns
|
79
|
-
combined_df = combined_df.drop(drop_cols) unless drop_cols.empty?
|
80
|
-
split_features_targets(combined_df, split_ys, target)
|
81
|
-
end
|
82
|
-
end
|
83
|
-
|
84
|
-
def cleanup
|
85
|
-
FileUtils.rm_rf(dir)
|
86
|
-
FileUtils.mkdir_p(dir)
|
87
|
-
end
|
88
|
-
|
89
|
-
def split_at
|
90
|
-
return nil if output_files.empty?
|
91
|
-
|
92
|
-
output_files.map { |file| File.mtime(file) }.max
|
93
|
-
end
|
94
|
-
|
95
|
-
private
|
96
|
-
|
97
|
-
def read_csv_batched(path)
|
98
|
-
Polars.read_csv_batched(path, batch_size: batch_size, **polars_args)
|
99
|
-
end
|
100
|
-
|
101
|
-
def df(path)
|
102
|
-
Polars.read_csv(path, **polars_args)
|
103
|
-
end
|
104
|
-
|
105
|
-
def output_files
|
106
|
-
Dir.glob("#{dir}/**/*.csv")
|
107
|
-
end
|
108
|
-
|
109
|
-
def files_for_segment(segment)
|
110
|
-
segment_dir = File.join(dir, segment.to_s)
|
111
|
-
Dir.glob(File.join(segment_dir, "**/*.csv")).sort
|
112
|
-
end
|
113
|
-
|
114
|
-
def current_file_for_segment(segment)
|
115
|
-
current_file = files_for_segment(segment).last
|
116
|
-
return new_file_path_for_segment(segment) if current_file.nil?
|
117
|
-
|
118
|
-
row_count = df(current_file).shape[0]
|
119
|
-
if row_count >= max_rows_per_file
|
120
|
-
new_file_path_for_segment(segment)
|
121
|
-
else
|
122
|
-
current_file
|
123
|
-
end
|
124
|
-
end
|
125
|
-
|
126
|
-
def new_file_path_for_segment(segment)
|
127
|
-
segment_dir = File.join(dir, segment.to_s)
|
128
|
-
file_number = Dir.glob(File.join(segment_dir, "*.csv")).count
|
129
|
-
File.join(segment_dir, "#{segment}_%04d.csv" % file_number)
|
130
|
-
end
|
131
|
-
|
132
|
-
def combine_dataframes(files)
|
133
|
-
dfs = files.map { |file| df(file) }
|
134
|
-
Polars.concat(dfs)
|
135
|
-
end
|
136
|
-
end
|
137
|
-
end
|
138
|
-
end
|
139
|
-
end
|
140
|
-
end
|