easy_ml 0.1.4 → 0.2.0.pre.rc1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
data/app/models/easy_ml/model.rb
CHANGED
@@ -1,68 +1,602 @@
|
|
1
|
-
|
1
|
+
# == Schema Information
|
2
|
+
#
|
3
|
+
# Table name: easy_ml_models
|
4
|
+
#
|
5
|
+
# id :bigint not null, primary key
|
6
|
+
# name :string not null
|
7
|
+
# model_type :string
|
8
|
+
# status :string
|
9
|
+
# dataset_id :bigint
|
10
|
+
# model_file_id :bigint
|
11
|
+
# configuration :json
|
12
|
+
# version :string not null
|
13
|
+
# root_dir :string
|
14
|
+
# file :json
|
15
|
+
# sha :string
|
16
|
+
# last_trained_at :datetime
|
17
|
+
# is_training :boolean
|
18
|
+
# created_at :datetime not null
|
19
|
+
# updated_at :datetime not null
|
20
|
+
#
|
21
|
+
require_relative "models/hyperparameters"
|
22
|
+
|
2
23
|
module EasyML
  # ActiveRecord model backing the `easy_ml_models` table.
  #
  # A Model ties together a dataset, a model-type specific adapter (looked up
  # through MODEL_OPTIONS), an optional persisted model file, and the
  # retraining job / runs / deploys associated with it. Most of the actual
  # machine-learning work is delegated to the adapter.
  class Model < ActiveRecord::Base
    self.table_name = "easy_ml_models"
    include Historiographer::Silent
    historiographer_mode :snapshot_only

    include EasyML::Concerns::Configurable
    include EasyML::Concerns::Versionable

    self.filter_attributes += [:configuration]

    # Maps the persisted `model_type` string to the adapter class name.
    MODEL_OPTIONS = {
      "xgboost" => "EasyML::Models::XGBoost",
    }
    # Descriptions of the available model types (value/label/description).
    MODEL_TYPES = [
      {
        value: "xgboost",
        label: "XGBoost",
        description: "Extreme Gradient Boosting, a scalable and accurate implementation of gradient boosting machines",
      },
    ].freeze
    MODEL_NAMES = MODEL_OPTIONS.keys.freeze
    MODEL_CONSTANTS = MODEL_OPTIONS.values.map(&:constantize)

    add_configuration_attributes :task, :objective, :hyperparameters, :evaluator, :callbacks, :metrics
    # Each adapter class may contribute additional configuration attributes.
    MODEL_CONSTANTS.flat_map(&:configuration_attributes).each do |attribute|
      add_configuration_attributes attribute
    end

    belongs_to :dataset
    belongs_to :model_file, class_name: "EasyML::ModelFile", foreign_key: "model_file_id", optional: true

    has_one :retraining_job, class_name: "EasyML::RetrainingJob"
    accepts_nested_attributes_for :retraining_job
    has_many :retraining_runs, class_name: "EasyML::RetrainingRun"
    has_many :deploys, class_name: "EasyML::Deploy"

    scope :deployed, -> { EasyML::ModelHistory.deployed }

    # @return [EasyML::Deploy, nil] the deploy with the highest id, if any
    def latest_deploy
      deploys.order(id: :desc).first
    end

    after_initialize :bump_version, if: -> { new_record? }
    after_initialize :set_defaults, if: -> { new_record? }
    before_save :save_model_file, if: -> { is_fit? && !is_history_class? && model_changed? && !@skip_save_model_file }

    VALID_TASKS = %i[regression classification].freeze

    # Descriptions of the supported tasks (value/label/description).
    TASK_TYPES = [
      {
        value: "classification",
        label: "Classification",
        description: "Predict categorical outcomes or class labels",
      },
      {
        value: "regression",
        label: "Regression",
        description: "Predict continuous numerical values",
      },
    ].freeze

    validates :name, presence: true
    validates :name, uniqueness: { case_sensitive: false }
    validates :task, presence: true
    # Accept the task both as a symbol and as a string.
    validates :task, inclusion: {
      in: VALID_TASKS.flat_map { |t| [t, t.to_s] },
      message: "must be one of: #{VALID_TASKS.join(", ")}",
    }
    validates :model_type, inclusion: { in: MODEL_NAMES }
    validates :dataset_id, presence: true
    validate :validate_metrics_allowed
    before_save :set_root_dir

    delegate :prepare_data, :preprocess, to: :adapter

    STATUSES = %w[development inference retired].freeze
    # Defines development?, inference? and retired?.
    STATUSES.each do |status_name|
      define_method "#{status_name}?" do
        # to_s keeps the predicate safe when status is nil or a symbol.
        status.to_s == status_name
      end
    end

    # @return [Boolean] true only when the is_training column is exactly true
    def training?
      is_training == true
    end

    # Marks the model as training and either enqueues a TrainingJob or trains inline.
    #
    # @param async [Boolean] enqueue EasyML::TrainingJob when true, otherwise call #actually_train
    def train(async: true)
      pending_run # Ensure we update the pending job before enqueuing in background so UI updates properly
      update(is_training: true)
      if async
        EasyML::TrainingJob.perform_later(id)
      else
        actually_train
      end
    end

    # Returns the associated retraining job, creating an inactive monthly one
    # when none exists. Also assigns this model's evaluator: from the existing
    # job when present, otherwise the default evaluator for the task.
    #
    # @return [EasyML::RetrainingJob]
    def get_retraining_job
      if retraining_job
        self.evaluator = retraining_job.evaluator
        return retraining_job
      end

      default_eval = EasyML::Core::ModelEvaluator.default_evaluator(task)
      self.evaluator = default_eval

      create_retraining_job(
        model: self,
        active: false,
        evaluator: default_eval,
        metric: default_eval[:metric],
        direction: default_eval[:direction],
        threshold: default_eval[:threshold],
        frequency: "month",
        at: { hour: 0, day_of_month: 1 },
      )
    end

    # Finds or creates the "pending" retraining run for this model.
    #
    # @return [EasyML::RetrainingRun]
    def pending_run
      job = get_retraining_job
      job.retraining_runs.find_or_create_by(status: "pending", model: self)
    end

    # Trains the model while holding the training lock: tunes hyperparameters
    # when the run asks for it, fits, and saves.
    #
    # @yield forwarded to the tuner and to #fit as the progress block
    # @return [EasyML::RetrainingRun] the reloaded run
    def actually_train(&progress_block)
      lock_model do
        run = pending_run
        run.wrap_training do
          best_params = nil
          if run.should_tune?
            best_params = hyperparameter_search(&progress_block)
          end
          fit(&progress_block)
          save
          [self, best_params]
        end
        update(is_training: false)
        run.reload
      ensure
        unlock!
        # If training raised before the flag was cleared above, clear it here so
        # the model is not left flagged as training. Callbacks are skipped on
        # purpose: a failed run should not trigger saving a model file.
        update_column(:is_training, false) if persisted? && is_training
      end
    end

    # Releases the training lock for this model.
    def unlock!
      Support::Lockable.unlock!(lock_key)
    end

    # Runs the given block while holding the training lock.
    def lock_model
      with_lock { yield }
    end

    # Acquires the training lock through EasyML::Support::Lockable and yields
    # the object it passes to the block.
    def with_lock
      EasyML::Support::Lockable.with_lock(lock_key, stale_timeout: 60, resources: 1) do |client|
        yield client
      end
    end

    # @return [String] lock key derived from the model's name and id
    def lock_key
      "training:#{name}:#{id}"
    end

    # Runs the tuner configured on the retraining job and writes the resulting
    # parameters onto this model's hyperparameters.
    #
    # @yield forwarded to the tuner as the progress block
    # @return the value returned by the tuner's #tune (the best parameters)
    def hyperparameter_search(&progress_block)
      tuner = retraining_job.tuner_config.symbolize_keys
      extra_params = {
        evaluator: evaluator,
        model: self,
        dataset: dataset,
      }.compact
      tuner.merge!(extra_params)
      tuner_instance = EasyML::Core::Tuner.new(tuner)
      tuner_instance.tune(&progress_block).tap do |best_params|
        best_params.each do |key, value|
          hyperparameters.send("#{key}=", value)
        end
      end
    end

    # @return [String, nil] alias for the status attribute
    def deployment_status
      status
    end

    # @return [String] the adapter's class name without its namespace
    def formatted_model_type
      adapter.class.name.split("::").last
    end

    # @return [String, nil] the version (a "%Y%m%d%H%M%S" timestamp string) in long, readable form
    def formatted_version
      return nil unless version

      Time.strptime(version, "%Y%m%d%H%M%S").strftime("%B %-d, %Y at %-l:%M %p")
    end

    # @return [Time, nil] creation time of the most recent retraining run
    def last_run_at
      last_run&.created_at
    end

    # @return [EasyML::RetrainingRun, nil] the retraining run with the highest id
    def last_run
      retraining_runs.order(id: :desc).first
    end

    # @return the model_version of the latest deploy, or nil when never deployed
    def inference_version
      latest_deploy&.model_version
    end

    alias_method :current_version, :inference_version
    alias_method :latest_version, :inference_version
    alias_method :deployed, :inference_version

    # @return the hyperparameters object built (once) by the adapter
    def hyperparameters
      @hypers ||= adapter.build_hyperparameters(@hyperparameters)
    end

    # @return the callbacks built (once) by the adapter
    def callbacks
      @cbs ||= adapter.build_callbacks(@callbacks)
    end

    # Force-loads the model file and the dataset, then predicts via the adapter.
    #
    # @param xs input passed through to the adapter's #predict
    def predict(xs)
      load_model!
      adapter.predict(xs)
    end

    # Writes the fitted model to a new versioned path, uploads it, stores the
    # model file record on this model and cleans up old files.
    #
    # @raise [RuntimeError] when the model has not been fit
    # @return [void] returns early (nil) when the adapter has no model loaded
    def save_model_file
      raise "No trained model! Need to train model before saving (call model.fit)" unless is_fit?
      return unless adapter.loaded?

      file = get_model_file

      bump_version(force: true)
      path = file.full_path(version)
      full_path = adapter.save_model_file(path)
      file.upload(full_path)

      file.save
      self.model_file = file
      cleanup
    end

    def feature_names
      adapter.feature_names
    end

    def cleanup!
      get_model_file&.cleanup!
    end

    # Cleans up model files, keeping those returned by #files_to_keep.
    def cleanup
      get_model_file&.cleanup(files_to_keep)
    end

    # Reports whether the adapter has a model loaded, attempting to load the
    # model file from disk first when it does not.
    #
    # @return [Boolean] false when a persisted model file is missing on disk
    def loaded?
      file = get_model_file
      return false if file.persisted? && !File.exist?(file.full_path.to_s)

      load_model_file unless adapter.loaded?
      adapter.loaded?
    end

    # Decides whether the fitted model differs from the deployed one.
    #
    # @return [Boolean]
    def model_changed?
      return false unless is_fit?

      # Look the deployed version up once; each call queries the deploys.
      deployed_version = inference_version
      return true if deployed_version.nil?
      return true if model_file.present? && !model_file.persisted?

      adapter.model_changed?(deployed_version.sha)
    end

    def feature_importances
      adapter.feature_importances
    end

    # @return [Boolean] true when a retraining job exists with batch_mode exactly true
    def fit_in_batches?
      retraining_job.present? && retraining_job.batch_mode == true
    end

    # Fits the model. Uses #fit_in_batches when batch mode is enabled;
    # otherwise refreshes the dataset and fits through the adapter.
    #
    # @param tuning [Boolean] passed through to the adapter
    # @param x_train, y_train, x_valid, y_valid optional data passed through to the adapter
    # @yield forwarded to the adapter as the progress block
    def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
      return fit_in_batches(**batch_args.merge!(tuning: tuning), &progress_block) if fit_in_batches?

      dataset.refresh
      adapter.fit(tuning: tuning, x_train: x_train, y_train: y_train, x_valid: x_valid, y_valid: y_valid, &progress_block)
      @is_fit = true
    end

    # @return [Hash] batch settings: defaults overridden by any non-nil values on the retraining job
    def batch_args
      defaults = {
        batch_size: 1024,
        batch_overlap: 3,
        batch_key: nil,
      }
      overrides = {
        batch_size: retraining_job&.batch_size,
        batch_overlap: retraining_job&.batch_overlap,
        batch_key: retraining_job&.batch_key,
      }.compact
      defaults.merge(overrides)
    end

    # @return the retraining job's batch_mode, or false when unset / no job
    def batch_mode
      retraining_job&.batch_mode || false
    end

    def prepare_callbacks(tune_started_at)
      adapter.prepare_callbacks(tune_started_at)
    end

    def after_tuning
      adapter.after_tuning
    end

    # Fits the model in batches through the adapter.
    #
    # @param checkpoint_dir defaults to tmp/xgboost_checkpoints under Rails.root
    # @yield forwarded to the adapter as the progress block
    def fit_in_batches(tuning: false, batch_size: nil, batch_overlap: nil, batch_key: nil, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"), &progress_block)
      adapter.fit_in_batches(tuning: tuning, batch_size: batch_size, batch_overlap: batch_overlap, batch_key: batch_key, checkpoint_dir: checkpoint_dir, &progress_block)
      @is_fit = true
    end

    attr_accessor :is_fit

    # @return [Boolean] true when the model file reports fit, otherwise whatever the adapter reports
    def is_fit?
      file = get_model_file
      return true if file.present? && file.fit?

      adapter.is_fit?
    end

    # @return [Boolean] true when #cannot_deploy_reasons is empty
    def deployable?
      cannot_deploy_reasons.none?
    end

    def decode_labels(ys, col: nil)
      dataset.decode_labels(ys, col: col)
    end

    # Evaluates predictions with EasyML::Core::ModelEvaluator. When y_pred is
    # not given, predictions and ground truth are taken from the dataset's
    # test split.
    #
    # @param evaluator defaults to this model's evaluator
    def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
      evaluator ||= self.evaluator
      if y_pred.nil?
        inputs = default_evaluation_inputs
        y_pred = inputs[:y_pred]
        y_true = inputs[:y_true]
        x_true = inputs[:x_true]
      end
      EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true, evaluator: evaluator)
    end

    # @return the explicitly assigned evaluator, falling back to #default_evaluator
    def evaluator
      instance_variable_get(:@evaluator) || default_evaluator
    end

    # @return the default evaluator for the task, or nil when no task is set
    def default_evaluator
      return nil unless task.present?

      EasyML::Core::ModelEvaluator.default_evaluator(task)
    end

    # @return [Hash] the raw assigned hyperparameters as a hash ({} when unset)
    def get_params
      @hyperparameters.to_h
    end

    # @return [Hash] metrics of the most recent run, or {} when there are none
    def evals
      last_run&.metrics || {}
    end

    # Looks up a single metric from the most recent run.
    #
    # @param metric [String, Symbol]
    # @return the metric value, or nil when there is no run or no such metric
    def metric_accessor(metric)
      evals.symbolize_keys[metric.to_sym]
    end

    # Define a reader per metric known to the evaluator at load time...
    EasyML::Core::ModelEvaluator.metrics.each do |metric_name|
      define_method metric_name do
        metric_accessor(metric_name)
      end
    end

    # ...and hand the evaluator a lambda that defines the same kind of reader
    # for a metric name it is called with later.
    EasyML::Core::ModelEvaluator.callbacks = lambda do |metric_name|
      EasyML::Model.define_method metric_name do
        metric_accessor(metric_name)
      end
    end

    # @return [Array<String>] metric names the evaluator reports for this task
    def allowed_metrics
      EasyML::Core::ModelEvaluator.metrics(task).map(&:to_s)
    end

    # @return [Array<String>] default metric names for the task ([] when task is blank or unknown)
    def default_metrics
      return [] unless task.present?

      case task.to_sym
      when :regression
        %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
      when :classification
        %w[accuracy_score precision_score recall_score f1_score]
      else
        []
      end
    end

    # NOTE: this overrides Module#constants for this class.
    #
    # @return [Hash] objectives, metrics, tasks, timezone and job constants
    def self.constants
      {
        objectives: objectives_by_model_type,
        metrics: metrics_by_task,
        tasks: TASK_TYPES,
        timezone: EasyML::Configuration.timezone_label,
        retraining_job_constants: EasyML::RetrainingJob.constants,
        tuner_job_constants: EasyML::TunerJob.constants,
      }
    end

    def self.metrics_by_task
      EasyML::Core::ModelEvaluator.metrics_by_task
    end

    # @return [Hash] each model type mapped to its adapter's OBJECTIVES_FRONTEND, with symbolized keys
    def self.objectives_by_model_type
      MODEL_OPTIONS
        .transform_values { |class_name| class_name.constantize.const_get(:OBJECTIVES_FRONTEND) }
        .deep_symbolize_keys
    end

    # @return [Hash] the record's attributes plus a :hyperparameters entry
    def attributes
      super.merge!(
        hyperparameters: hyperparameters.to_h,
      )
    end

    class CannotdeployError < StandardError
    end

    # Deploys the most recent retraining run.
    #
    # @param async [Boolean] passed through to the run's #deploy
    # @raise [CannotdeployError] when the model has no retraining run
    def deploy(async: true)
      run = last_run
      raise CannotdeployError, "Model has no retraining run to deploy" if run.nil?

      run.deploy(async: async)
    end

    # Saves and snapshots the model for inference, then resets version and
    # model file so the next training does not clash with the snapshot.
    #
    # @raise [CannotdeployError] with the first reason from #cannot_deploy_reasons
    # @return the snapshot
    def actually_deploy
      reasons = cannot_deploy_reasons
      raise CannotdeployError, reasons.first if reasons.any?

      # Prepare the inference model by freezing + saving the model, dataset, and datasource
      # (This creates ModelHistory, DatasetHistory, etc)
      save_model_file
      self.sha = model_file.sha
      save
      dataset.upload_remote_files
      snapshot.tap do
        # Prepare the model to be retrained (reset values so they don't conflict with our snapshotted version)
        bump_version(force: true)
        dataset.bump_versions(version)
        self.model_file = new_model_file!
        save
      end
    end

    # @return [Array<String>] reasons the model cannot be deployed (empty when deployable)
    def cannot_deploy_reasons
      reasons = []
      reasons << "Model has not been trained" unless is_fit?
      reasons << "Dataset has no target" unless dataset.target.present?
      reasons << "Cannot perform inference using an in-memory datasource" if dataset.datasource.in_memory?
      reasons
    end

    # root_dir is derived from the model name; only the derived value is accepted.
    #
    # @raise [RuntimeError] when value differs from the derived root_dir
    def root_dir=(value)
      raise "Cannot override value of root_dir!" unless value.to_s == root_dir.to_s

      write_attribute(:root_dir, value)
    end

    # before_save hook: persists the derived root_dir.
    def set_root_dir
      write_attribute(:root_dir, root_dir)
    end

    # @return [String] "<engine root_dir>/models/<underscored name>"
    def root_dir
      EasyML::Engine.root_dir.join("models").join(underscored_name).to_s
    end

    # Downloads the model file (see #download_model_file) and loads it into the adapter.
    def load_model(force: false)
      download_model_file(force: force)
      load_model_file
    end

    # Normalizes metrics to a unique array of strings.
    #
    # @param value [Array, String, Symbol, nil] nil is treated as "no metrics"
    def metrics=(value)
      values = case value
               when nil then []
               when Array then value
               else [value]
               end
      @metrics = values.map(&:to_s).uniq
    end

    # @return the memoized adapter instance for this model's model_type
    # @raise [RuntimeError] when model_type has no entry in MODEL_OPTIONS
    def adapter
      @adapter ||= begin
          adapter_class = MODEL_OPTIONS[model_type]
          raise "Don't know how to use model adapter #{model_type}!" unless adapter_class.present?

          adapter_class.constantize.new(self)
        end
    end

    private

    # Predicts on the dataset's test split and returns inputs for #evaluate.
    def default_evaluation_inputs
      x_true, y_true = dataset.test(split_ys: true)
      y_pred = predict(x_true)
      {
        x_true: x_true,
        y_true: y_true,
        y_pred: y_pred,
      }
    end

    # Existing model file, or a newly built (unsaved) one.
    def get_model_file
      model_file || new_model_file!
    end

    # Builds a model file record using the S3 settings from EasyML::Configuration.
    def new_model_file!
      build_model_file(
        root_dir: root_dir,
        model: self,
        s3_bucket: EasyML::Configuration.s3_bucket,
        s3_region: EasyML::Configuration.s3_region,
        s3_access_key_id: EasyML::Configuration.s3_access_key_id,
        s3_secret_access_key: EasyML::Configuration.s3_secret_access_key,
        s3_prefix: prefix,
      )
    end

    # S3 prefix for this model: the configured prefix (if any) joined with the name.
    def prefix
      s3_prefix = EasyML::Configuration.s3_prefix
      s3_prefix.present? ? File.join(s3_prefix, name) : name
    end

    def load_model!
      load_model(force: true)
      load_dataset
    end

    def load_dataset
      dataset.load_dataset
    end

    # Loads the model file into the adapter when it exists on disk.
    def load_model_file
      return unless model_file&.full_path && File.exist?(model_file.full_path)

      adapter.load_model_file(model_file.full_path)
    end

    # Downloads the model file for persisted models, skipping the download when
    # the model is already loaded and force is false.
    def download_model_file(force: false)
      return unless persisted?
      return if loaded? && !force

      get_model_file.download
    end

    # Paths of model files referenced by this model, any training model, or any
    # deployed model history.
    def files_to_keep
      inference_models = EasyML::ModelHistory.deployed
      training_models = EasyML::Model.all

      ([self] + training_models + inference_models).compact.map(&:model_file).compact.map(&:full_path).uniq
    end

    # Lower-cased name with whitespace replaced by underscores; falls back to
    # the class name when the model has no name yet.
    def underscored_name
      name = self.name || self.class.name.split("::").last
      name.gsub(/\s{2,}/, " ").gsub(/\s/, "_").downcase
    end

    # after_initialize hook for new records.
    # NOTE(review): the default status :training is not one of STATUSES, so no
    # predicate above matches it; confirm this is intentional.
    def set_defaults
      self.model_type ||= "xgboost"
      self.status ||= :training
      self.metrics ||= default_metrics
    end

    # Validation: every configured metric must be known to the evaluator for this task.
    def validate_metrics_allowed
      candidates = Array(metrics)
      return if candidates.empty?

      allowed = allowed_metrics
      unknown_metrics = candidates.reject { |metric| allowed.include?(metric) }
      return if unknown_metrics.empty?

      errors.add(:metrics,
                 "don't know how to handle #{"metric".pluralize(unknown_metrics.size)} #{unknown_metrics.join(", ")}, use EasyML::Core::ModelEvaluator.register(:name, Evaluator, :regression|:classification)")
    end
  end
end
|
601
|
+
|
602
|
+
require_relative "models/xgboost"
|