easy_ml 0.1.4 → 0.2.0.pre.rc1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,10 +1,549 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
# == Schema Information
|
2
|
+
#
|
3
|
+
# Table name: easy_ml_models
|
4
|
+
#
|
5
|
+
# id :bigint not null, primary key
|
6
|
+
# name :string not null
|
7
|
+
# model_type :string
|
8
|
+
# status :string
|
9
|
+
# dataset_id :bigint
|
10
|
+
# configuration :json
|
11
|
+
# version :string not null
|
12
|
+
# root_dir :string
|
13
|
+
# file :json
|
14
|
+
# created_at :datetime not null
|
15
|
+
# updated_at :datetime not null
|
3
16
|
module EasyML
  module Models
    # XGBoost-backed implementation of BaseModel.
    #
    # Builds booster-specific hyperparameters and training callbacks from the
    # model's configuration, trains through the `::XGBoost` gem (either in one
    # call or batch-by-batch with on-disk checkpoints), predicts, and saves /
    # loads booster files.
    class XGBoost < BaseModel
      Hyperparameters = EasyML::Models::Hyperparameters::XGBoost

      # Objectives accepted for each task (and, for classification, for each
      # label cardinality). Checked by #validate_objective before training.
      OBJECTIVES = {
        classification: {
          binary: %w[binary:logistic binary:hinge],
          multiclass: %w[multi:softmax multi:softprob],
        },
        regression: %w[reg:squarederror reg:logistic],
      }.freeze

      # Human-readable label/description for each objective (the name suggests
      # it is serialized for the frontend).
      OBJECTIVES_FRONTEND = {
        classification: [
          { value: "binary:logistic", label: "Binary Logistic", description: "For binary classification" },
          { value: "binary:hinge", label: "Binary Hinge", description: "For binary classification with hinge loss" },
          { value: "multi:softmax", label: "Multiclass Softmax", description: "For multiclass classification" },
          { value: "multi:softprob", label: "Multiclass Probability",
            description: "For multiclass classification with probability output" },
        ],
        regression: [
          { value: "reg:squarederror", label: "Squared Error", description: "For regression with squared loss" },
          { value: "reg:logistic", label: "Logistic", description: "For regression with logistic loss" },
        ],
      }.freeze

      add_configuration_attributes :early_stopping_rounds
      attr_accessor :xgboost_model, :booster, :progress_callback

      # Build the booster-specific hyperparameter object for this model.
      #
      # @param params [Hash, nil] raw hyperparameters with string or symbol keys;
      #   nil is treated as an empty hash. The caller's hash is not mutated.
      # @return [Hyperparameters::GBTree, Hyperparameters::Dart, Hyperparameters::GBLinear, nil]
      #   nil when +params+ is not a Hash.
      # @raise [RuntimeError] if :booster names an unknown booster type.
      def build_hyperparameters(params)
        params = {} if params.nil?
        return nil unless params.is_a?(Hash)

        # Work on a symbol-keyed copy so the caller's configuration is left untouched.
        params = params.to_h.symbolize_keys
        params[:booster] = :gbtree if params[:booster].blank?

        klass = case params[:booster].to_sym
                when :gbtree then Hyperparameters::GBTree
                when :dart then Hyperparameters::Dart
                when :gblinear then Hyperparameters::GBLinear
                else raise "Unknown booster type: #{params[:booster]}"
                end

        # The model's objective always overrides any objective present in +params+.
        klass.new(params.merge(objective: model.objective))
      end

      # When a Weights & Biases API key is configured, make sure +params+ holds
      # wandb, evals_callback and progress_callback entries (appending whichever
      # are missing), then move evals_callback to the front.
      #
      # Mutates +params+ in place; its entries are replaced by deep-symbolized copies.
      #
      # @param params [Array<Hash>] callback configurations
      # @return [Array<Hash>, nil] +params+, or nil when no wandb key is configured
      def add_auto_configurable_callbacks(params)
        return unless EasyML::Configuration.wandb_api_key.present?

        params.map!(&:deep_symbolize_keys)

        unless callback_configured?(params, :wandb)
          params << {
            callback_type: :wandb,
            project_name: model.name,
            log_feature_importance: false,
            define_metric: false,
          }
        end
        params << { callback_type: :evals_callback } unless callback_configured?(params, :evals_callback)
        params << { callback_type: :progress_callback } unless callback_configured?(params, :progress_callback)

        # evals_callback must run first. partition (unlike sort_by!) keeps the
        # relative order of the remaining callbacks, and comparing normalized
        # types also recognises string values such as "evals_callback".
        evals, others = params.partition { |conf| xgboost_callback_type_of(conf) == :evals_callback }
        params.replace(evals + others)
      end

      # Instantiate the training callbacks described by +params+.
      #
      # Accepts entries shaped either `{ callback_type: :wandb, ... }` or
      # `{ wandb: { ... } }`, with string or symbol keys. Duplicate callback types
      # are dropped from +params+ in place (first occurrence wins). The
      # individual config hashes are not mutated.
      #
      # @param params [Array<Hash>, Object] callback configurations
      # @return [Array] callback instances; [] when +params+ is not an Array
      # @raise [RuntimeError] for an unknown callback type
      def build_callbacks(params)
        return [] unless params.is_a?(Array)

        add_auto_configurable_callbacks(params)

        # Normalize before deduplicating so "wandb", :wandb and { wandb: {...} }
        # are all treated as the same callback.
        params.uniq! { |conf| xgboost_callback_type_of(conf) }

        params.map do |raw_conf|
          conf = raw_conf.to_h.symbolize_keys
          if conf.key?(:callback_type)
            callback_type = conf[:callback_type]
          else
            callback_type = conf.keys.first&.to_sym
            conf = conf.values.first.to_h.symbolize_keys
          end

          klass = case callback_type&.to_sym
                  when :wandb then Wandb::XGBoostCallback
                  when :evals_callback then EasyML::Models::XGBoost::EvalsCallback
                  when :progress_callback then EasyML::Models::XGBoost::ProgressCallback
                  end
          raise "Unknown callback type #{callback_type}" unless klass.present?

          klass.new(conf).tap do |instance|
            instance.instance_variable_set(:@callback_type, callback_type)
            instance.send(:model=, model) if instance.respond_to?(:model=)
          end
        end
      end

      # Call #after_tuning on every model callback that defines it.
      def after_tuning
        model.callbacks.each do |callback|
          callback.after_tuning if callback.respond_to?(:after_tuning)
        end
      end

      # Point the wandb callback (if any) at the tuner's project, then let every
      # callback that defines #prepare_callback see the tuner.
      def prepare_callbacks(tuner)
        set_wandb_project(tuner.project_name)

        model.callbacks.each do |callback|
          callback.prepare_callback(tuner) if callback.respond_to?(:prepare_callback)
        end
      end

      # Set the project name on the wandb callback; no-op (nil) without one.
      def set_wandb_project(project_name)
        wandb_callback = find_wandb_callback
        return unless wandb_callback.present?

        wandb_callback.project_name = project_name
      end

      # @return [String, nil] the wandb callback's project name, nil without one
      def get_wandb_project
        wandb_callback = find_wandb_callback
        return nil unless wandb_callback.present?

        wandb_callback.project_name
      end

      # Clear the wandb callback's project name. Always returns nil.
      def delete_wandb_project
        wandb_callback = find_wandb_callback
        return nil unless wandb_callback.present?

        wandb_callback.project_name = nil
      end

      # @return [Boolean] true once a booster with feature names is held
      def is_fit?
        loaded?
      end

      # Train a booster on the dataset splits, or on caller-supplied data.
      #
      # @param tuning [Boolean] when true the wandb project name is neither
      #   defaulted nor cleared (presumably the tuner manages it via #prepare_callbacks)
      # @param x_train [Polars::DataFrame, nil] features to train on; when nil the
      #   dataset's train/valid splits are used
      # @param y_train [Polars::DataFrame, nil] labels for +x_train+
      # @param x_valid [Polars::DataFrame, nil] evaluation features; when nil (and
      #   +x_train+ is given) the dataset's valid split is used
      # @param y_valid [Polars::DataFrame, nil] labels for +x_valid+
      # @param progress_block [Proc, nil] stored in #progress_callback
      # @return [::XGBoost::Booster] the trained booster (also kept in @booster)
      # @raise [ArgumentError] if the task/objective combination is invalid
      def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
        validate_objective

        if x_train.nil?
          d_train, d_valid, = prepare_data
        else
          # Caller-supplied data takes precedence over the dataset's splits.
          d_train = preprocess(x_train, y_train)
          d_valid = if x_valid.nil?
                      preprocess(*dataset.valid(split_ys: true))
                    else
                      preprocess(x_valid, y_valid)
                    end
        end

        evals = [[d_train, "train"], [d_valid, "eval"]]
        self.progress_callback = progress_block
        set_default_wandb_project_name unless tuning
        begin
          @booster = base_model.train(hyperparameters.to_h,
                                      d_train,
                                      evals: evals,
                                      num_boost_round: hyperparameters["n_estimators"],
                                      callbacks: model.callbacks,
                                      early_stopping_rounds: symbolized_hyperparameters[:early_stopping_rounds])
        ensure
          # Reset even on failure, so the next run does not reuse a stale project name.
          delete_wandb_project unless tuning
        end
        @booster
      end

      # Give the wandb callback a timestamped project name unless one is already set.
      def set_default_wandb_project_name
        return if get_wandb_project.present?

        started_at = EasyML::Support::UTC.now
        project_name = "#{model.name}_#{started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
        set_wandb_project(project_name)
      end

      # Train incrementally over batches of the train split, checkpointing the
      # booster to +checkpoint_dir+ after every batch and resuming from it.
      #
      # Boosting rounds (n_estimators) are spread evenly over the batches; the
      # last batch absorbs any remainder. Each batch is trained together with up
      # to +batch_overlap+ preceding batches.
      #
      # @param checkpoint_dir [Pathname, String] directory for "<batch>.json" checkpoints
      # @return [::XGBoost::Booster] the trained booster
      # @raise [RuntimeError] if the train split yields no batches
      def fit_in_batches(tuning: false, batch_size: 1024, batch_key: nil, batch_start: nil, batch_overlap: 1,
                         checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"))
        validate_objective
        checkpoint_dir = Pathname.new(checkpoint_dir.to_s) # accept String as well as Pathname
        ensure_directory_exists(checkpoint_dir)
        set_default_wandb_project_name unless tuning

        begin
          # Prepare validation data
          x_valid, y_valid = dataset.valid(split_ys: true)
          d_valid = preprocess(x_valid, y_valid)

          hyperparams = symbolized_hyperparameters
          num_iterations = hyperparams[:n_estimators]
          early_stopping_rounds = hyperparams[:early_stopping_rounds]

          num_batches = dataset.train(batch_size: batch_size, batch_start: batch_start, batch_key: batch_key).count
          raise "Cannot fit in batches: the training split produced no batches" if num_batches.zero?

          iterations_per_batch = num_iterations / num_batches
          stopping_points = (1..num_batches).map { |n| n * iterations_per_batch }
          stopping_points[-1] = num_iterations

          current_iteration = 0
          current_batch = 0
          cb_container = nil

          callbacks = model.callbacks.nil? ? [] : model.callbacks.dup
          callbacks << ::XGBoost::EvaluationMonitor.new(period: 1)

          # Generate batches without loading the full dataset
          batches = dataset.train(split_ys: true, batch_size: batch_size, batch_start: batch_start, batch_key: batch_key)
          prev_xs = []
          prev_ys = []

          while current_iteration < num_iterations
            x_train, y_train = batches.next

            # Train on this batch plus the retained overlap from earlier batches.
            merged_x = merged_y = nil
            if prev_xs.any?
              merged_x = Polars.concat([x_train] + prev_xs)
              merged_y = Polars.concat([y_train] + prev_ys)
            end

            if batch_overlap > 0
              prev_xs << x_train
              prev_ys << y_train
              if prev_xs.size > batch_overlap
                prev_xs.shift
                prev_ys.shift
              end
            end

            if merged_x.present?
              x_train = merged_x
              y_train = merged_y
            end

            d_train = preprocess(x_train, y_train)
            evals = [[d_train, "train"], [d_valid, "eval"]]

            # Resume from the checkpoint written after the previous batch.
            model_file = current_batch.zero? ? nil : checkpoint_dir.join("#{current_batch - 1}.json").to_s

            @booster = booster_class.new(
              params: hyperparams,
              cache: [d_train, d_valid],
              model_file: model_file,
            )
            loop_callbacks = callbacks.dup
            loop_callbacks << ::XGBoost::EarlyStopping.new(rounds: early_stopping_rounds) if early_stopping_rounds
            cb_container = ::XGBoost::CallbackContainer.new(loop_callbacks)
            @booster = cb_container.before_training(@booster) if current_iteration.zero?

            stopping_point = stopping_points[current_batch]
            while current_iteration < stopping_point
              break if cb_container.before_iteration(@booster, current_iteration, d_train, evals)

              @booster.update(d_train, current_iteration)
              break if cb_container.after_iteration(@booster, current_iteration, d_train, evals)

              current_iteration += 1
            end
            current_iteration = stopping_point # In case of early stopping

            @booster.save_model(checkpoint_dir.join("#{current_batch}.json").to_s)
            current_batch += 1
          end

          @booster = cb_container.after_training(@booster) if cb_container
        ensure
          delete_wandb_project unless tuning
        end
        @booster
      end

      # @return [Array<String>] the booster's text dump (one entry per tree)
      def weights
        @booster.get_dump
      end

      # Predict for +xs+; classification tasks return class indices.
      #
      # @param xs [Polars::DataFrame, ::XGBoost::DMatrix]
      # @raise [RuntimeError] without a trained booster, for nil input, or when
      #   the feature columns do not match the model's
      def predict(xs)
        raise "No trained model! Train a model before calling predict" unless @booster.present?
        raise "Cannot predict on nil — XGBoost" if xs.nil?

        begin
          y_pred = @booster.predict(preprocess(xs))
        rescue StandardError => e
          raise unless e.message.match?(/Number of columns does not match/)

          raise %(
            >>>>><<<<<
            XGBoost received predict with unexpected features!
            >>>>><<<<<

            Model expects features:
            #{feature_names}
            Model received features:
            #{xs.columns}
          )
        end

        case task.to_sym
        when :classification
          to_classification(y_pred)
        else
          y_pred
        end
      end

      # Class probabilities for +data+: per-class rows for multiclass output,
      # [1 - p, p] pairs for single-score output.
      #
      # @param data [Polars::DataFrame, ::XGBoost::DMatrix, Array] rows to score
      # @raise [RuntimeError] without a trained booster
      def predict_proba(data)
        raise "No trained model! Train a model before calling predict_proba" unless @booster.present?

        # Use the same conversion as #predict; bare arrays go straight to a DMatrix.
        dmat = data.is_a?(Array) ? d_matrix_class.new(data) : preprocess(data)
        y_pred = @booster.predict(dmat)

        if y_pred.first.is_a?(Array)
          # multiple classes
          y_pred
        else
          y_pred.map { |v| [1 - v, v] }
        end
      end

      # Drop the in-memory model and booster.
      def unload
        @xgboost_model = nil
        @booster = nil
      end

      # @return [Boolean] true when a booster with feature names is held
      def loaded?
        @booster.present? && @booster.feature_names.any?
      end

      # Load a booster from +path+ unless one is already loaded.
      def load_model_file(path)
        return if loaded?

        initialize_model do
          attrs = {
            params: hyperparameters.to_h.symbolize_keys.compact,
            model_file: path,
          }.compact
          booster_class.new(**attrs)
        end
      end

      def external_model
        @booster
      end

      def external_model=(booster)
        @booster = booster
      end

      # Compare the SHA-256 of the current booster's saved JSON with +prev_hash+.
      #
      # @return [Boolean] false when no booster is loaded
      def model_changed?(prev_hash)
        return false unless loaded?

        # Tempfile.create with a block returns the block's value and removes the file.
        current_model_hash = Tempfile.create(["xgboost_model", ".json"]) do |tempfile|
          @booster.save_model(tempfile.path)
          Digest::SHA256.file(tempfile.path).hexdigest
        end
        current_model_hash != prev_hash
      end

      # Save the booster as JSON at +path+ (".json" is appended when missing).
      #
      # @return [String] the path actually written
      def save_model_file(path)
        path = path.to_s
        ensure_directory_exists(File.dirname(path))
        extension = Pathname.new(path).extname.delete_prefix(".")
        path = "#{path}.json" unless extension == "json"

        @booster.save_model(path)
        path
      end

      def feature_names
        @booster.feature_names
      end

      # Importance per feature, normalized to sum to 1 (all 0.0 when the
      # booster reports no importance at all).
      #
      # @return [Hash{String => Float}]
      def feature_importances
        score = @booster.score(importance_type: @importance_type || "gain")
        names = @booster.feature_names
        scores = names.map { |name| score[name] || 0.0 }
        total = scores.sum.to_f
        # Avoid 0.0 / 0.0 => NaN, e.g. for a booster without any splits.
        return names.to_h { |name| [name, 0.0] } if total.zero?

        names.zip(scores.map { |s| s / total }).to_h
      end

      def base_model
        ::XGBoost
      end

      # Build (once) and memoize DMatrix objects for the train/valid/test splits.
      #
      # @return [Array(::XGBoost::DMatrix, ::XGBoost::DMatrix, ::XGBoost::DMatrix)]
      def prepare_data
        if @d_train.nil?
          x_sample, y_sample = dataset.train(split_ys: true, limit: 5)
          preprocess(x_sample, y_sample) # Ensure we fail fast if the dataset is misconfigured
          x_train, y_train = dataset.train(split_ys: true)
          x_valid, y_valid = dataset.valid(split_ys: true)
          x_test, y_test = dataset.test(split_ys: true)
          @d_train = preprocess(x_train, y_train)
          @d_valid = preprocess(x_valid, y_valid)
          @d_test = preprocess(x_test, y_test)
        end

        [@d_train, @d_valid, @d_test]
      end

      # Convert a data frame (and optional labels) into an ::XGBoost::DMatrix
      # whose feature names are the frame's column names. DMatrix input is
      # returned unchanged.
      #
      # @raise [RuntimeError] with a sample of the String/Categorical columns when
      #   the DMatrix cannot be built
      def preprocess(xs, ys = nil)
        return xs if xs.is_a?(::XGBoost::DMatrix)

        orig_xs = xs.dup
        column_names = xs.columns
        xs = _preprocess(xs)
        ys = ys.nil? ? nil : _preprocess(ys).flatten
        kwargs = { label: ys }.compact
        begin
          ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
            dmat.feature_names = column_names
          end
        rescue StandardError => e
          problematic_columns = orig_xs.schema.select { |k, v| [Polars::Categorical, Polars::String].include?(v) }
          problematic_xs = orig_xs.select(problematic_columns.keys)
          raise %(
            Error building data for XGBoost.
            Apply preprocessing to columns
            >>>>><<<<<
            #{problematic_columns.keys}
            >>>>><<<<<
            A sample of your dataset:
            #{problematic_xs[0..5]}

            #{if ys.present?
                %(
              This may also be due to your targets:
              #{ys[0..5]}
            )
              else
                ""
              end}
          )
        end
      end

      def self.hyperparameter_constants
        EasyML::Models::Hyperparameters::XGBoost.hyperparameter_constants
      end

      private

      def booster_class
        ::XGBoost::Booster
      end

      def d_matrix_class
        ::XGBoost::DMatrix
      end

      def model_class
        ::XGBoost::Model
      end

      # Hyperparameters as a plain symbol-keyed Hash, whatever key style #to_h uses.
      def symbolized_hyperparameters
        hyperparameters.to_h.symbolize_keys
      end

      # The model's Wandb::XGBoostCallback instance (exact class match), or nil.
      def find_wandb_callback
        Array(model.callbacks).find { |cb| cb.instance_of?(Wandb::XGBoostCallback) }
      end

      # Normalized callback type for one config, accepting both
      # `{ callback_type: :x, ... }` and `{ x: { ... } }` shapes with string or
      # symbol keys/values. Returns nil when it cannot be determined.
      def xgboost_callback_type_of(conf)
        conf = conf.to_h.symbolize_keys
        type = conf.key?(:callback_type) ? conf[:callback_type] : conf.keys.first
        type&.to_sym
      end

      # Whether +params+ already contains a callback of +type+.
      def callback_configured?(params, type)
        params.any? { |conf| xgboost_callback_type_of(conf) == type }
      end

      # Run a single boosting round, creating the booster on first use.
      def fit_batch(d_train, current_iteration, evals, cb_container)
        if @booster.nil?
          # Booster.new takes params/cache/model_file only; early stopping is
          # expected to come from the callbacks in +cb_container+.
          @booster = booster_class.new(params: symbolized_hyperparameters, cache: [d_train] + evals.map(&:first))
        end

        @booster = cb_container.before_training(@booster)
        cb_container.before_iteration(@booster, current_iteration, d_train, evals)
        @booster.update(d_train, current_iteration)
        cb_container.after_iteration(@booster, current_iteration, d_train, evals)
      end

      # Turn a data frame into an array of row arrays XGBoost can ingest.
      # Arrays pass through unchanged. Note: nil becomes 0.0 through #to_f.
      def _preprocess(df)
        return df if df.is_a?(Array)

        df.to_a.map do |row|
          row.values.map do |value|
            case value
            when Time
              value.to_i # Convert Time to Unix timestamp
            when Date
              value.to_time.to_i # Convert Date to Unix timestamp
            when String
              value
            when TrueClass, FalseClass
              value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
            when Integer
              value
            else
              value.to_f # Ensure everything else is converted to a float
            end
          end
        end
      end

      # Create the ::XGBoost::Model wrapper and its booster (from the block, if given).
      def initialize_model
        @xgboost_model = model_class.new(n_estimators: symbolized_hyperparameters[:n_estimators])
        if block_given?
          @booster = yield
        else
          attrs = {
            params: hyperparameters.to_h.symbolize_keys,
          }.deep_compact
          @booster = booster_class.new(**attrs)
        end
        @xgboost_model.instance_variable_set(:@booster, @booster)
      end

      # Ensure the configured objective is allowed for the task (and, for
      # classification, for the number of distinct labels in the dataset).
      #
      # @raise [ArgumentError] when the task is missing/unknown or the objective is not allowed
      def validate_objective
        objective = hyperparameters.objective
        unless task.present?
          raise ArgumentError,
                "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
        end

        allowed_objectives = case task.to_sym
                             when :classification
                               _, ys = dataset.data(split_ys: true)
                               classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multiclass
                               OBJECTIVES[:classification][classification_type]
                             else
                               OBJECTIVES[task.to_sym]
                             end
        if allowed_objectives.nil?
          raise ArgumentError,
                "unknown task #{task}. Please specify either regression or classification"
        end

        return if allowed_objectives.map(&:to_s).include?(objective.to_s)

        raise ArgumentError,
              "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
      end

      # Map raw predictions to class labels.
      def to_classification(y_pred)
        if y_pred.first.is_a?(Array)
          # Multiple classes: index of the highest per-class score.
          y_pred.map { |scores| scores.each_with_index.max_by(&:first).last }
        else
          # Single score per row: threshold at 0.5.
          y_pred.map { |v| v > 0.5 ? 1 : 0 }
        end
      end
    end
  end
end
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# == Schema Information
|
2
|
+
#
|
3
|
+
# Table name: easy_ml_predictions
|
4
|
+
#
|
5
|
+
# id :bigint not null, primary key
|
6
|
+
# model_id :bigint not null
|
7
|
+
# model_history_id :bigint
|
8
|
+
# prediction_type :string
|
9
|
+
# prediction_value :jsonb
|
10
|
+
# raw_input :jsonb
|
11
|
+
# normalized_input :jsonb
|
12
|
+
# created_at :datetime not null
|
13
|
+
# updated_at :datetime not null
|
14
|
+
#
|
15
|
+
module EasyML
  # One prediction recorded for a model, stored alongside the raw and the
  # normalized inputs that produced it.
  class Prediction < ActiveRecord::Base
    self.table_name = "easy_ml_predictions"

    belongs_to :model
    # Optional link to an easy_ml model history row.
    belongs_to :model_history, optional: true

    validates :model_id, presence: true
    validates :prediction_type, presence: true, inclusion: { in: %w[regression classification] }
    validates :prediction_value, :raw_input, :normalized_input, presence: true

    # @return [Object, nil] the "value" entry of the prediction_value JSON
    def prediction
      prediction_value["value"]
    end

    # @return [Object, nil] the "probabilities" entry of the prediction_value JSON
    def probabilities
      prediction_value["probabilities"]
    end

    # @return [Boolean] whether prediction_type is "regression"
    def regression?
      "regression" == prediction_type
    end

    # @return [Boolean] whether prediction_type is "classification"
    def classification?
      "classification" == prediction_type
    end
  end
end
|