easy_ml 0.1.3 → 0.2.0.pre.rc1
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -4
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
data/app/models/easy_ml/models/xgboost.rb
@@ -1,9 +1,549 @@
-
-
+# == Schema Information
+#
+# Table name: easy_ml_models
+#
+#  id            :bigint           not null, primary key
+#  name          :string           not null
+#  model_type    :string
+#  status        :string
+#  dataset_id    :bigint
+#  configuration :json
+#  version       :string           not null
+#  root_dir      :string
+#  file          :json
+#  created_at    :datetime         not null
+#  updated_at    :datetime         not null
 module EasyML
   module Models
-    class XGBoost <
-
+    class XGBoost < BaseModel
+      Hyperparameters = EasyML::Models::Hyperparameters::XGBoost
+
+      OBJECTIVES = {
+        classification: {
+          binary: %w[binary:logistic binary:hinge],
+          multiclass: %w[multi:softmax multi:softprob],
+        },
+        regression: %w[reg:squarederror reg:logistic],
+      }
+
+      OBJECTIVES_FRONTEND = {
+        classification: [
+          { value: "binary:logistic", label: "Binary Logistic", description: "For binary classification" },
+          { value: "binary:hinge", label: "Binary Hinge", description: "For binary classification with hinge loss" },
+          { value: "multi:softmax", label: "Multiclass Softmax", description: "For multiclass classification" },
+          { value: "multi:softprob", label: "Multiclass Probability",
+            description: "For multiclass classification with probability output" },
+        ],
+        regression: [
+          { value: "reg:squarederror", label: "Squared Error", description: "For regression with squared loss" },
+          { value: "reg:logistic", label: "Logistic", description: "For regression with logistic loss" },
+        ],
+      }
+
+      add_configuration_attributes :early_stopping_rounds
+      attr_accessor :xgboost_model, :booster
+
+      def build_hyperparameters(params)
+        params = {} if params.nil?
+        return nil unless params.is_a?(Hash)
+
+        params.to_h.symbolize_keys!
+
+        params[:booster] = :gbtree unless params.key?(:booster)
+
+        klass = case params[:booster].to_sym
+                when :gbtree
+                  Hyperparameters::GBTree
+                when :dart
+                  Hyperparameters::Dart
+                when :gblinear
+                  Hyperparameters::GBLinear
+                else
+                  raise "Unknown booster type: #{booster}"
+                end
+        raise "Unknown booster type #{booster}" unless klass.present?
+
+        overrides = {
+          objective: model.objective,
+        }
+        params.merge!(overrides)
+
+        klass.new(params)
+      end
+
+      def add_auto_configurable_callbacks(params)
+        if EasyML::Configuration.wandb_api_key.present?
+          params.map!(&:deep_symbolize_keys)
+          unless params.any? { |c| c[:callback_type]&.to_sym == :wandb }
+            params << {
+              callback_type: :wandb,
+              project_name: model.name,
+              log_feature_importance: false,
+              define_metric: false,
+            }
+          end
+
+          unless params.any? { |c| c[:callback_type]&.to_sym == :evals_callback }
+            params << {
+              callback_type: :evals_callback,
+            }
+          end
+
+          unless params.any? { |c| c[:callback_type]&.to_sym == :progress_callback }
+            params << {
+              callback_type: :progress_callback,
+            }
+          end
+
+          params.sort_by! { |c| c[:callback_type] == :evals_callback ? 0 : 1 }
+        end
+      end
+
+      def build_callbacks(params)
+        return [] unless params.is_a?(Array)
+
+        add_auto_configurable_callbacks(params)
+
+        params.uniq! { |c| c[:callback_type] }
+
+        params.map do |conf|
+          conf.symbolize_keys!
+          if conf.key?(:callback_type)
+            callback_type = conf[:callback_type]
+          else
+            callback_type = conf.keys.first.to_sym
+            conf = conf.values.first.symbolize_keys!
+          end
+
+          klass = case callback_type.to_sym
+                  when :wandb then Wandb::XGBoostCallback
+                  when :evals_callback then EasyML::Models::XGBoost::EvalsCallback
+                  when :progress_callback then EasyML::Models::XGBoost::ProgressCallback
+                  end
+          raise "Unknown callback type #{callback_type}" unless klass.present?
+
+          klass.new(conf).tap do |instance|
+            instance.instance_variable_set(:@callback_type, callback_type)
+            instance.send(:model=, model) if instance.respond_to?(:model=)
+          end
+        end
+      end
+
+      def after_tuning
+        model.callbacks.each do |callback|
+          callback.after_tuning if callback.respond_to?(:after_tuning)
+        end
+      end
+
+      def prepare_callbacks(tuner)
+        set_wandb_project(tuner.project_name)
+
+        model.callbacks.each do |callback|
+          callback.prepare_callback(tuner) if callback.respond_to?(:prepare_callback)
+        end
+      end
+
+      def set_wandb_project(project_name)
+        wandb_callback = model.callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
+        return unless wandb_callback.present?
+        wandb_callback.project_name = project_name
+      end
+
+      def get_wandb_project
+        wandb_callback = model.callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
+        return nil unless wandb_callback.present?
+        wandb_callback.project_name
+      end
+
+      def delete_wandb_project
+        wandb_callback = model.callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
+        return nil unless wandb_callback.present?
+        wandb_callback.project_name = nil
+      end
+
+      def is_fit?
+        @booster.present? && @booster.feature_names.any?
+      end
+
+      attr_accessor :progress_callback
+
+      def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
+        validate_objective
+
+        d_train, d_valid, = prepare_data if x_train.nil?
+
+        evals = [[d_train, "train"], [d_valid, "eval"]]
+        self.progress_callback = progress_block
+        set_default_wandb_project_name unless tuning
+        @booster = base_model.train(hyperparameters.to_h,
+                                    d_train,
+                                    evals: evals,
+                                    num_boost_round: hyperparameters["n_estimators"],
+                                    callbacks: model.callbacks,
+                                    early_stopping_rounds: hyperparameters.to_h.dig("early_stopping_rounds"))
+        delete_wandb_project unless tuning
+        return @booster
+      end
+
+      def set_default_wandb_project_name
+        return if get_wandb_project.present?
+
+        started_at = EasyML::Support::UTC.now
+        project_name = "#{model.name}_#{started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
+        set_wandb_project(project_name)
+      end
+
+      def fit_in_batches(tuning: false, batch_size: 1024, batch_key: nil, batch_start: nil, batch_overlap: 1, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"))
+        validate_objective
+        ensure_directory_exists(checkpoint_dir)
+        set_default_wandb_project_name unless tuning
+
+        # Prepare validation data
+        x_valid, y_valid = dataset.valid(split_ys: true)
+        d_valid = preprocess(x_valid, y_valid)
+
+        num_iterations = hyperparameters.to_h[:n_estimators]
+        early_stopping_rounds = hyperparameters.to_h[:early_stopping_rounds]
+
+        num_batches = dataset.train(batch_size: batch_size, batch_start: batch_start, batch_key: batch_key).count
+        iterations_per_batch = num_iterations / num_batches
+        stopping_points = (1..num_batches).to_a.map { |n| n * iterations_per_batch }
+        stopping_points[-1] = num_iterations
+
+        current_iteration = 0
+        current_batch = 0
+
+        callbacks = model.callbacks.nil? ? [] : model.callbacks.dup
+        callbacks << ::XGBoost::EvaluationMonitor.new(period: 1)
+
+        # Generate batches without loading full dataset
+        batches = dataset.train(split_ys: true, batch_size: batch_size, batch_start: batch_start, batch_key: batch_key)
+        prev_xs = []
+        prev_ys = []
+
+        while current_iteration < num_iterations
+          # Load the next batch
+          x_train, y_train = batches.next
+
+          # Add batch_overlap from previous batch if applicable
+          merged_x, merged_y = nil, nil
+          if prev_xs.any?
+            merged_x = Polars.concat([x_train] + prev_xs.flatten)
+            merged_y = Polars.concat([y_train] + prev_ys.flatten)
+          end
+
+          if batch_overlap > 0
+            prev_xs << [x_train]
+            prev_ys << [y_train]
+            if prev_xs.size > batch_overlap
+              prev_xs = prev_xs[1..]
+              prev_ys = prev_ys[1..]
+            end
+          end
+
+          if merged_x.present?
+            x_train = merged_x
+            y_train = merged_y
+          end
+
+          d_train = preprocess(x_train, y_train)
+          evals = [[d_train, "train"], [d_valid, "eval"]]
+
+          model_file = current_batch == 0 ? nil : checkpoint_dir.join("#{current_batch - 1}.json").to_s
+
+          @booster = booster_class.new(
+            params: hyperparameters.to_h.symbolize_keys,
+            cache: [d_train, d_valid],
+            model_file: model_file,
+          )
+          loop_callbacks = callbacks.dup
+          if early_stopping_rounds
+            loop_callbacks << ::XGBoost::EarlyStopping.new(rounds: early_stopping_rounds)
+          end
+          cb_container = ::XGBoost::CallbackContainer.new(loop_callbacks)
+          @booster = cb_container.before_training(@booster) if current_iteration == 0
+
+          stopping_point = stopping_points[current_batch]
+          while current_iteration < stopping_point
+            break if cb_container.before_iteration(@booster, current_iteration, d_train, evals)
+            @booster.update(d_train, current_iteration)
+            break if cb_container.after_iteration(@booster, current_iteration, d_train, evals)
+            current_iteration += 1
+          end
+          current_iteration = stopping_point # In case of early stopping
+
+          @booster.save_model(checkpoint_dir.join("#{current_batch}.json").to_s)
+          current_batch += 1
+        end
+
+        @booster = cb_container.after_training(@booster)
+        delete_wandb_project unless tuning
+        return @booster
+      end
+
+      def weights
+        @booster.save_model("tmp/xgboost_model.json")
+        @booster.get_dump
+      end
+
+      def predict(xs)
+        raise "No trained model! Train a model before calling predict" unless @booster.present?
+        raise "Cannot predict on nil — XGBoost" if xs.nil?
+
+        begin
+          y_pred = @booster.predict(preprocess(xs))
+        rescue StandardError => e
+          raise e unless e.message.match?(/Number of columns does not match/)
+
+          raise %(
+            >>>>><<<<<
+            XGBoost received predict with unexpected features!
+            >>>>><<<<<
+
+            Model expects features:
+            #{feature_names}
+            Model received features:
+            #{xs.columns}
+          )
+        end
+
+        case task.to_sym
+        when :classification
+          to_classification(y_pred)
+        else
+          y_pred
+        end
+      end
+
+      def predict_proba(data)
+        dmat = DMatrix.new(data)
+        y_pred = @booster.predict(dmat)
+
+        if y_pred.first.is_a?(Array)
+          # multiple classes
+          y_pred
+        else
+          y_pred.map { |v| [1 - v, v] }
+        end
+      end
+
+      def unload
+        @xgboost_model = nil
+        @booster = nil
+      end
+
+      def loaded?
+        @booster.present? && @booster.feature_names.any?
+      end
+
+      def load_model_file(path)
+        return if loaded?
+
+        initialize_model do
+          attrs = {
+            params: hyperparameters.to_h.symbolize_keys.compact,
+            model_file: path,
+          }.compact
+          booster_class.new(**attrs)
+        end
+      end
+
+      def external_model
+        @booster
+      end
+
+      def external_model=(booster)
+        @booster = booster
+      end
+
+      def model_changed?(prev_hash)
+        return false unless @booster.present? && @booster.feature_names.any?
+
+        current_model_hash = nil
+        Tempfile.create(["xgboost_model", ".json"]) do |tempfile|
+          @booster.save_model(tempfile.path)
+          tempfile.rewind
+          JSON.parse(tempfile.read)
+          current_model_hash = Digest::SHA256.file(tempfile.path).hexdigest
+        end
+        current_model_hash != prev_hash
+      end
+
+      def save_model_file(path)
+        path = path.to_s
+        ensure_directory_exists(File.dirname(path))
+        extension = Pathname.new(path).extname.gsub("\.", "")
+        path = "#{path}.json" unless extension == "json"
+
+        @booster.save_model(path)
+        path
+      end
+
+      def feature_names
+        @booster.feature_names
+      end
+
+      def feature_importances
+        score = @booster.score(importance_type: @importance_type || "gain")
+        scores = @booster.feature_names.map { |k| score[k] || 0.0 }
+        total = scores.sum.to_f
+        fi = scores.map { |s| s / total }
+        @booster.feature_names.zip(fi).to_h
+      end
+
+      def base_model
+        ::XGBoost
+      end
+
+      def prepare_data
+        if @d_train.nil?
+          x_sample, y_sample = dataset.train(split_ys: true, limit: 5)
+          preprocess(x_sample, y_sample) # Ensure we fail fast if the dataset is misconfigured
+          x_train, y_train = dataset.train(split_ys: true)
+          x_valid, y_valid = dataset.valid(split_ys: true)
+          x_test, y_test = dataset.test(split_ys: true)
+          @d_train = preprocess(x_train, y_train)
+          @d_valid = preprocess(x_valid, y_valid)
+          @d_test = preprocess(x_test, y_test)
+        end
+
+        [@d_train, @d_valid, @d_test]
+      end
+
+      def preprocess(xs, ys = nil)
+        return xs if xs.is_a?(::XGBoost::DMatrix)
+
+        orig_xs = xs.dup
+        column_names = xs.columns
+        xs = _preprocess(xs)
+        ys = ys.nil? ? nil : _preprocess(ys).flatten
+        kwargs = { label: ys }.compact
+        begin
+          ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
+            dmat.feature_names = column_names
+          end
+        rescue StandardError => e
+          problematic_columns = orig_xs.schema.select { |k, v| [Polars::Categorical, Polars::String].include?(v) }
+          problematic_xs = orig_xs.select(problematic_columns.keys)
+          raise %(
+            Error building data for XGBoost.
+            Apply preprocessing to columns
+            >>>>><<<<<
+            #{problematic_columns.keys}
+            >>>>><<<<<
+            A sample of your dataset:
+            #{problematic_xs[0..5]}
+
+            #{if ys.present?
+                %(
+                  This may also be due to your targets:
+                  #{ys[0..5]}
+                )
+              else
+                ""
+              end}
+          )
+        end
+      end
+
+      def self.hyperparameter_constants
+        EasyML::Models::Hyperparameters::XGBoost.hyperparameter_constants
+      end
+
+      private
+
+      def booster_class
+        ::XGBoost::Booster
+      end
+
+      def d_matrix_class
+        ::XGBoost::DMatrix
+      end
+
+      def model_class
+        ::XGBoost::Model
+      end
+
+      def fit_batch(d_train, current_iteration, evals, cb_container)
+        if @booster.nil?
+          @booster = booster_class.new(params: @hyperparameters.to_h, cache: [d_train] + evals.map do |d|
+            d[0]
+          end, early_stopping_rounds: @hyperparameters.to_h.dig(:early_stopping_rounds))
+        end
+
+        @booster = cb_container.before_training(@booster)
+        cb_container.before_iteration(@booster, current_iteration, d_train, evals)
+        @booster.update(d_train, current_iteration)
+        cb_container.after_iteration(@booster, current_iteration, d_train, evals)
+      end
+
+      def _preprocess(df)
+        return df if df.is_a?(Array)
+
+        df.to_a.map do |row|
+          row.values.map do |value|
+            case value
+            when Time
+              value.to_i # Convert Time to Unix timestamp
+            when Date
+              value.to_time.to_i # Convert Date to Unix timestamp
+            when String
+              value
+            when TrueClass, FalseClass
+              value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
+            when Integer
+              value
+            else
+              value.to_f # Ensure everything else is converted to a float
+            end
+          end
+        end
+      end
+
+      def initialize_model
+        @xgboost_model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
+        if block_given?
+          @booster = yield
+        else
+          attrs = {
+            params: hyperparameters.to_h.symbolize_keys,
+          }.deep_compact
+          @booster = booster_class.new(**attrs)
+        end
+        @xgboost_model.instance_variable_set(:@booster, @booster)
+      end
+
+      def validate_objective
+        objective = hyperparameters.objective
+        unless task.present?
+          raise ArgumentError,
+                "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
+        end
+
+        case task.to_sym
+        when :classification
+          _, ys = dataset.data(split_ys: true)
+          classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
+          allowed_objectives = OBJECTIVES[:classification][classification_type]
+        else
+          allowed_objectives = OBJECTIVES[task.to_sym]
+        end
+        return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
+
+        raise ArgumentError,
+              "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
+      end
+
+      def to_classification(y_pred)
+        if y_pred.first.is_a?(Array)
+          # multiple classes
+          y_pred.map do |v|
+            v.map.with_index.max_by { |v2, _| v2 }.last
+          end
+        else
+          y_pred.map { |v| v > 0.5 ? 1 : 0 }
+        end
+      end
     end
   end
 end
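The adapter added above centers on fit, fit_in_batches, predict, and predict_proba. A rough usage sketch follows; it is not taken from the gem's documentation: `model` stands for a configured EasyML model record with a refreshed dataset, and the `adapter` and `x_test` names are hypothetical, with only the adapter methods shown in the diff assumed to exist.

    adapter = model.adapter                    # hypothetical accessor for the XGBoost adapter instance
    adapter.fit                                # builds train/eval DMatrices via prepare_data, then trains
    adapter.fit_in_batches(batch_size: 2048)   # alternative: checkpointed batch training for large datasets
    labels = adapter.predict(x_test)           # class labels for classification tasks, floats for regression
    probas = adapter.predict_proba(x_test)     # per-row probabilities, e.g. [p0, p1] for binary models
    adapter.save_model_file("tmp/model.json")  # persists the booster as JSON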
data/app/models/easy_ml/prediction.rb
@@ -0,0 +1,44 @@
+# == Schema Information
+#
+# Table name: easy_ml_predictions
+#
+#  id               :bigint           not null, primary key
+#  model_id         :bigint           not null
+#  model_history_id :bigint
+#  prediction_type  :string
+#  prediction_value :jsonb
+#  raw_input        :jsonb
+#  normalized_input :jsonb
+#  created_at       :datetime         not null
+#  updated_at       :datetime         not null
+#
+module EasyML
+  class Prediction < ActiveRecord::Base
+    self.table_name = "easy_ml_predictions"
+
+    belongs_to :model
+    belongs_to :model_history, optional: true
+
+    validates :model_id, presence: true
+    validates :prediction_type, presence: true, inclusion: { in: %w[regression classification] }
+    validates :prediction_value, presence: true
+    validates :raw_input, presence: true
+    validates :normalized_input, presence: true
+
+    def prediction
+      prediction_value["value"]
+    end
+
+    def probabilities
+      prediction_value["probabilities"]
+    end
+
+    def regression?
+      prediction_type == "regression"
+    end
+
+    def classification?
+      prediction_type == "classification"
+    end
+  end
+end
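For reference, a minimal sketch of writing and reading a prediction row, assuming only the columns from the schema comment and the methods defined above; the attribute values and the `model` variable are made up for illustration.

    prediction = EasyML::Prediction.create!(
      model_id: model.id,
      prediction_type: "classification",
      prediction_value: { "value" => 1, "probabilities" => [0.12, 0.88] },
      raw_input: { "amount" => "1,200" },
      normalized_input: { "amount" => 1200.0 },
    )
    prediction.prediction       # => 1
    prediction.probabilities    # => [0.12, 0.88]
    prediction.classification?  # => true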