easy_ml 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +270 -0
- data/Rakefile +12 -0
- data/app/models/easy_ml/model.rb +59 -0
- data/app/models/easy_ml/models/xgboost.rb +9 -0
- data/app/models/easy_ml/models.rb +5 -0
- data/lib/easy_ml/core/model.rb +29 -0
- data/lib/easy_ml/core/model_core.rb +181 -0
- data/lib/easy_ml/core/model_evaluator.rb +137 -0
- data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
- data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
- data/lib/easy_ml/core/models/xgboost.rb +10 -0
- data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
- data/lib/easy_ml/core/models.rb +10 -0
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
- data/lib/easy_ml/core/tuner/adapters.rb +10 -0
- data/lib/easy_ml/core/tuner.rb +105 -0
- data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
- data/lib/easy_ml/core/uploaders.rb +7 -0
- data/lib/easy_ml/core.rb +9 -0
- data/lib/easy_ml/core_ext/pathname.rb +9 -0
- data/lib/easy_ml/core_ext.rb +5 -0
- data/lib/easy_ml/data/dataloader.rb +6 -0
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
- data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
- data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
- data/lib/easy_ml/data/dataset/splits.rb +11 -0
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
- data/lib/easy_ml/data/dataset/splitters.rb +9 -0
- data/lib/easy_ml/data/dataset.rb +430 -0
- data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
- data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
- data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
- data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
- data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
- data/lib/easy_ml/data/datasource.rb +33 -0
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
- data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
- data/lib/easy_ml/data/preprocessor.rb +238 -0
- data/lib/easy_ml/data/utils.rb +50 -0
- data/lib/easy_ml/data.rb +8 -0
- data/lib/easy_ml/deployment.rb +5 -0
- data/lib/easy_ml/engine.rb +26 -0
- data/lib/easy_ml/initializers/inflections.rb +4 -0
- data/lib/easy_ml/logging.rb +38 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
- data/lib/easy_ml/support/age.rb +27 -0
- data/lib/easy_ml/support/est.rb +1 -0
- data/lib/easy_ml/support/file_rotate.rb +23 -0
- data/lib/easy_ml/support/git_ignorable.rb +66 -0
- data/lib/easy_ml/support/synced_directory.rb +134 -0
- data/lib/easy_ml/support/utc.rb +1 -0
- data/lib/easy_ml/support.rb +10 -0
- data/lib/easy_ml/trainer.rb +92 -0
- data/lib/easy_ml/transforms.rb +29 -0
- data/lib/easy_ml/version.rb +5 -0
- data/lib/easy_ml.rb +23 -0
- metadata +353 -0
data/lib/easy_ml/core/model_evaluator.rb
@@ -0,0 +1,137 @@
+module EasyML
+  module Core
+    class ModelEvaluator
+      require "numo/narray"
+
+      EVALUATORS = {
+        mean_absolute_error: lambda { |y_pred, y_true|
+          (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
+        },
+        mean_squared_error: lambda { |y_pred, y_true|
+          ((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean
+        },
+        root_mean_squared_error: lambda { |y_pred, y_true|
+          Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean)
+        },
+        r2_score: lambda { |y_pred, y_true|
+          # Convert inputs to Numo::DFloat for numerical operations
+          y_true = Numo::DFloat.cast(y_true)
+          y_pred = Numo::DFloat.cast(y_pred)
+
+          # Calculate the mean of the true values
+          mean_y = y_true.mean
+
+          # Calculate Total Sum of Squares (SS_tot)
+          ss_tot = ((y_true - mean_y)**2).sum
+
+          # Calculate Residual Sum of Squares (SS_res)
+          ss_res = ((y_true - y_pred)**2).sum
+
+          # Handle the edge case where SS_tot is zero
+          if ss_tot.zero?
+            if ss_res.zero?
+              # Perfect prediction when both SS_tot and SS_res are zero
+              1.0
+            else
+              # Undefined R² when SS_tot is zero but SS_res is not
+              Float::NAN
+            end
+          else
+            # Calculate R²
+            1 - (ss_res / ss_tot)
+          end
+        },
+        accuracy_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          y_pred.eq(y_true).count_true.to_f / y_pred.size
+        },
+        precision_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+          predicted_positives = y_pred.eq(1).count_true
+          return 0 if predicted_positives == 0
+
+          true_positives.to_f / predicted_positives
+        },
+        recall_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+          actual_positives = y_true.eq(1).count_true
+          true_positives.to_f / actual_positives
+        },
+        f1_score: lambda { |y_pred, y_true|
+          precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
+          recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
+          return 0 unless (precision + recall) > 0
+
+          2 * (precision * recall) / (precision + recall)
+        }
+      }
+
+      class << self
+        def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
+          y_pred = normalize_input(y_pred)
+          y_true = normalize_input(y_true)
+          check_size(y_pred, y_true)
+
+          metrics_results = {}
+
+          model.metrics.each do |metric|
+            if metric.is_a?(Module) || metric.is_a?(Class)
+              unless metric.respond_to?(:evaluate)
+                raise "Metric #{metric} must respond to #evaluate in order to be used as a custom evaluator"
+              end
+
+              metrics_results[metric.name] = metric.evaluate(y_pred, y_true)
+            elsif EVALUATORS.key?(metric.to_sym)
+              metrics_results[metric.to_sym] =
+                EVALUATORS[metric.to_sym].call(y_pred, y_true)
+            end
+          end
+
+          if evaluator.present?
+            if evaluator.is_a?(Class)
+              response = evaluator.new.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            elsif evaluator.respond_to?(:evaluate)
+              response = evaluator.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            elsif evaluator.respond_to?(:call)
+              response = evaluator.call(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            else
+              raise "Don't know how to use CustomEvaluator. Must be a class that responds to evaluate or lambda"
+            end
+
+            if response.is_a?(Hash)
+              metrics_results.merge!(response)
+            else
+              metrics_results[:custom] = response
+            end
+          end
+
+          metrics_results
+        end
+
+        def check_size(y_pred, y_true)
+          raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
+        end
+
+        def normalize_input(input)
+          case input
+          when Polars::DataFrame
+            if input.columns.count > 1
+              raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
+            end
+
+            normalize_input(input[input.columns.first])
+          when Polars::Series, Array
+            Numo::DFloat.cast(input)
+          else
+            raise ArgumentError, "Don't know how to evaluate model with y_pred type #{input.class}"
+          end
+        end
+      end
+    end
+  end
+end
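For orientation: a minimal usage sketch of the evaluator above, not part of the diff. The only hard requirement the `evaluate` signature places on `model:` is a `#metrics` accessor, so a stand-in object works here; `#present?` comes from ActiveSupport, which the gem's Rails engine loads.

require "ostruct"

# Stand-in for a model: evaluate only reads model.metrics in this path.
model = OpenStruct.new(metrics: %w[mean_absolute_error r2_score])

EasyML::Core::ModelEvaluator.evaluate(
  model: model,
  y_pred: [2.1, 3.9, 6.2],
  y_true: [2.0, 4.0, 6.0]
)
# => { mean_absolute_error: ..., r2_score: ... }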
data/lib/easy_ml/core/models/hyperparameters/base.rb
@@ -0,0 +1,34 @@
+module EasyML
+  module Models
+    module Hyperparameters
+      class Base
+        include GlueGun::DSL
+
+        attribute :learning_rate, :float, default: 0.01
+        attribute :max_iterations, :integer, default: 100
+        attribute :batch_size, :integer, default: 32
+        attribute :regularization, :float, default: 0.0001
+
+        def to_h
+          attributes
+        end
+
+        def merge(other)
+          return self if other.nil?
+
+          other_hash = other.is_a?(Hyperparameters) ? other.to_h : other
+          merged_hash = to_h.merge(other_hash)
+          self.class.new(**merged_hash)
+        end
+
+        def [](key)
+          send(key) if respond_to?(key)
+        end
+
+        def []=(key, value)
+          send("#{key}=", value) if respond_to?("#{key}=")
+        end
+      end
+    end
+  end
+end
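The base class reads like a lightweight struct with hash-style access. A sketch of the semantics it defines, assuming GlueGun::DSL accepts keyword arguments at construction (which `merge` itself relies on):

params = EasyML::Models::Hyperparameters::Base.new(learning_rate: 0.05)
params[:learning_rate]               # => 0.05, via #[]
params[:batch_size] = 64             # via #[]=
merged = params.merge(regularization: 0.001) # returns a new instance, not in-place
merged.to_h                          # full attribute hash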
data/lib/easy_ml/core/models/hyperparameters/xgboost.rb
@@ -0,0 +1,19 @@
+module EasyML
+  module Models
+    module Hyperparameters
+      class XGBoost < Base
+        include GlueGun::DSL
+
+        attribute :learning_rate, :float, default: 0.1
+        attribute :max_depth, :integer, default: 6
+        attribute :n_estimators, :integer, default: 100
+        attribute :booster, :string, default: "gbtree"
+        attribute :objective, :string, default: "reg:squarederror"
+
+        validates :objective,
+                  inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
+                                      reg:logistic] }
+      end
+    end
+  end
+end
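A short sketch of this subclass in use; the `inclusion` validation restricts `objective` to the six listed values (exactly when the check fires depends on how GlueGun::DSL wires ActiveModel validations, e.g. at `#valid?`):

hp = EasyML::Models::Hyperparameters::XGBoost.new(objective: "multi:softmax", max_depth: 8)
hp.to_h # inherits Base's hash conversion; includes booster: "gbtree", etc.
# objective: "rank:pairwise" is outside the allowlist and would fail validation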
data/lib/easy_ml/core/models/xgboost_core.rb
@@ -0,0 +1,220 @@
+require "wandb"
+module EasyML
+  module Core
+    module Models
+      module XGBoostCore
+        OBJECTIVES = {
+          classification: {
+            binary: %w[binary:logistic binary:hinge],
+            multi_class: %w[multi:softmax multi:softprob]
+          },
+          regression: %w[reg:squarederror reg:logistic]
+        }
+
+        def self.included(base)
+          base.class_eval do
+            attribute :evaluator
+
+            dependency :callbacks, { array: true } do |dep|
+              dep.option :wandb do |opt|
+                opt.set_class Wandb::XGBoostCallback
+                opt.bind_attribute :log_model, default: false
+                opt.bind_attribute :log_feature_importance, default: true
+                opt.bind_attribute :importance_type, default: "gain"
+                opt.bind_attribute :define_metric, default: true
+                opt.bind_attribute :project_name
+              end
+            end
+
+            dependency :hyperparameters do |dep|
+              dep.set_class EasyML::Models::Hyperparameters::XGBoost
+              dep.bind_attribute :batch_size, default: 32
+              dep.bind_attribute :learning_rate, default: 1.1
+              dep.bind_attribute :max_depth, default: 6
+              dep.bind_attribute :n_estimators, default: 100
+              dep.bind_attribute :booster, default: "gbtree"
+              dep.bind_attribute :objective, default: "reg:squarederror"
+            end
+          end
+        end
+
+        attr_accessor :model, :booster
+
+        def predict(xs)
+          raise "No trained model! Train a model before calling predict" unless @booster.present?
+          raise "Cannot predict on nil — XGBoost" if xs.nil?
+
+          y_pred = @booster.predict(preprocess(xs))
+
+          case task.to_sym
+          when :classification
+            to_classification(y_pred)
+          else
+            y_pred
+          end
+        end
+
+        def predict_proba(data)
+          dmat = DMatrix.new(data)
+          y_pred = @booster.predict(dmat)
+
+          if y_pred.first.is_a?(Array)
+            # multiple classes
+            y_pred
+          else
+            y_pred.map { |v| [1 - v, v] }
+          end
+        end
+
+        def load(path = nil)
+          path ||= file
+          path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
+
+          raise "No existing model at #{path}" unless File.exist?(path)
+
+          initialize_model do
+            booster_class.new(params: hyperparameters.to_h, model_file: path)
+          end
+        end
+
+        def _save_model_file(path)
+          puts "XGBoost received path #{path}"
+          @booster.save_model(path)
+        end
+
+        def feature_importances
+          @model.booster.feature_names.zip(@model.feature_importances).to_h
+        end
+
+        def base_model
+          ::XGBoost
+        end
+
+        def customize_callbacks
+          yield callbacks
+        end
+
+        private
+
+        def booster_class
+          ::XGBoost::Booster
+        end
+
+        def d_matrix_class
+          ::XGBoost::DMatrix
+        end
+
+        def model_class
+          ::XGBoost::Model
+        end
+
+        def train
+          validate_objective
+
+          xs = xs.to_a.map(&:values)
+          ys = ys.to_a.map(&:values)
+          dtrain = d_matrix_class.new(xs, label: ys)
+          @model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
+        end
+
+        def train_in_batches
+          validate_objective
+
+          # Initialize the model with the first batch
+          @model = nil
+          @booster = nil
+          x_valid, y_valid = dataset.valid(split_ys: true)
+          x_train, y_train = dataset.train(split_ys: true)
+          fit_batch(x_train, y_train, x_valid, y_valid)
+        end
+
+        def _preprocess(df)
+          df.to_a.map do |row|
+            row.values.map do |value|
+              case value
+              when Time
+                value.to_i # Convert Time to Unix timestamp
+              when Date
+                value.to_time.to_i # Convert Date to Unix timestamp
+              when String
+                value
+              when TrueClass, FalseClass
+                value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
+              when Integer
+                value
+              else
+                value.to_f # Ensure everything else is converted to a float
+              end
+            end
+          end
+        end
+
+        def preprocess(xs, ys = nil)
+          column_names = xs.columns
+          xs = _preprocess(xs)
+          ys = ys.nil? ? nil : _preprocess(ys).flatten
+          kwargs = { label: ys }.compact
+          ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
+            dmat.instance_variable_set(:@feature_names, column_names)
+          end
+        end
+
+        def initialize_model
+          @model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
+          @booster = yield
+          @model.instance_variable_set(:@booster, @booster)
+        end
+
+        def validate_objective
+          objective = hyperparameters.objective
+          unless task.present?
+            raise ArgumentError,
+                  "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
+          end
+
+          case task.to_sym
+          when :classification
+            _, ys = dataset.data(split_ys: true)
+            classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
+            allowed_objectives = OBJECTIVES[:classification][classification_type]
+          else
+            allowed_objectives = OBJECTIVES[task.to_sym]
+          end
+          return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
+
+          raise ArgumentError,
+                "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
+        end
+
+        def fit_batch(x_train, y_train, x_valid, y_valid)
+          d_train = preprocess(x_train, y_train)
+          d_valid = preprocess(x_valid, y_valid)
+
+          evals = [[d_train, "train"], [d_valid, "eval"]]
+
+          # # If this is the first batch, create the booster
+          if @booster.nil?
+            initialize_model do
+              base_model.train(@hyperparameters.to_h, d_train,
+                               num_boost_round: @hyperparameters.to_h.dig("n_estimators"), evals: evals, callbacks: callbacks)
+            end
+          else
+            # Update the existing booster with the new batch
+            @model.update(d_train)
+          end
+        end
+
+        def to_classification(y_pred)
+          if y_pred.first.is_a?(Array)
+            # multiple classes
+            y_pred.map do |v|
+              v.map.with_index.max_by { |v2, _| v2 }.last
+            end
+          else
+            y_pred.map { |v| v > 0.5 ? 1 : 0 }
+          end
+        end
+      end
+    end
+  end
+end
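A hedged inference sketch for the predict path above, assuming `model` has already been trained (so `@booster` is set) and `x_test` stands for a Polars::DataFrame of features, which is what `preprocess` expects:

model.task = :classification
labels = model.predict(x_test)       # class labels via to_classification
probs  = model.predict_proba(x_test) # binary case: [[1 - p, p], ...] per row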
data/lib/easy_ml/core/tuner/adapters/base_adapter.rb
@@ -0,0 +1,63 @@
+module EasyML
+  module Core
+    class Tuner
+      module Adapters
+        class BaseAdapter
+          include GlueGun::DSL
+
+          def defaults
+            {}
+          end
+
+          attribute :model
+          attribute :config, :hash
+          attribute :project_name, :string
+          attribute :tune_started_at
+          attribute :y_true
+          attribute :x_true
+
+          def run_trial(trial)
+            config = deep_merge_defaults(self.config.clone)
+            suggest_parameters(trial, config)
+            model.fit
+            yield model
+          end
+
+          def configure_callbacks
+            raise "Subclasses of Tuner::Adapter::BaseAdapter must define #configure_callbacks"
+          end
+
+          def suggest_parameters(trial, config)
+            defaults.keys.each do |param_name|
+              param_value = suggest_parameter(trial, param_name, config)
+              model.hyperparameters.send("#{param_name}=", param_value)
+            end
+          end
+
+          def deep_merge_defaults(config)
+            defaults.deep_merge(config) do |_key, default_value, config_value|
+              if default_value.is_a?(Hash) && config_value.is_a?(Hash)
+                default_value.merge(config_value)
+              else
+                config_value
+              end
+            end
+          end
+
+          def suggest_parameter(trial, param_name, config)
+            param_config = config[param_name]
+            min = param_config[:min]
+            max = param_config[:max]
+            log = param_config[:log]
+
+            if log
+              trial.suggest_loguniform(param_name.to_s, min, max)
+            else
+              trial.suggest_uniform(param_name.to_s, min, max)
+            end
+          end
+        end
+      end
+    end
+  end
+end
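Adapters supply the search space (`#defaults`, one `{ min:, max:, log: }` hash per tunable hyperparameter) plus callback wiring. A hypothetical subclass, purely for illustration (no LightGBM adapter exists in this gem):

class LightGBMAdapter < EasyML::Core::Tuner::Adapters::BaseAdapter
  def defaults
    # log: true routes this parameter to trial.suggest_loguniform in suggest_parameter
    { learning_rate: { min: 0.01, max: 0.3, log: true } }
  end

  def configure_callbacks; end # no callbacks to wire up
end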
data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
@@ -0,0 +1,50 @@
+require_relative "base_adapter"
+
+module EasyML
+  module Core
+    class Tuner
+      module Adapters
+        class XGBoostAdapter < BaseAdapter
+          include GlueGun::DSL
+
+          def defaults
+            {
+              learning_rate: {
+                min: 0.001,
+                max: 0.1,
+                log: true
+              },
+              n_estimators: {
+                min: 100,
+                max: 1_000
+              },
+              max_depth: {
+                min: 2,
+                max: 20
+              }
+            }
+          end
+
+          def configure_callbacks
+            model.customize_callbacks do |callbacks|
+              return unless callbacks.present?
+
+              wandb_callback = callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
+              return unless wandb_callback.present?
+
+              wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
+              wandb_callback.custom_loggers = [
+                lambda do |booster, _epoch, _hist|
+                  dtrain = model.send(:preprocess, x_true, y_true)
+                  y_pred = booster.predict(dtrain)
+                  metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+                  Wandb.log(metrics)
+                end
+              ]
+            end
+          end
+        end
+      end
+    end
+  end
+end
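Because `BaseAdapter#deep_merge_defaults` merges user config over these defaults key by key, overriding a single bound keeps the rest of the search space intact; a sketch:

config = { learning_rate: { max: 0.3 } }
# deep_merge_defaults(config) yields, approximately:
# { learning_rate: { min: 0.001, max: 0.3, log: true },
#   n_estimators:  { min: 100, max: 1_000 },
#   max_depth:     { min: 2, max: 20 } }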
data/lib/easy_ml/core/tuner.rb
@@ -0,0 +1,105 @@
+require "optuna"
+require_relative "tuner/adapters"
+
+module EasyML
+  module Core
+    class Tuner
+      include GlueGun::DSL
+
+      attribute :model
+      attribute :dataset
+      attribute :project_name, :string
+      attribute :task, :string
+      attribute :config, :hash, default: {}
+      attribute :metrics, :array
+      attribute :objective, :string
+      attribute :n_trials, default: 100
+      attribute :callbacks, :array
+      attr_accessor :study, :results
+
+      dependency :adapter, lazy: false do |dep|
+        dep.option :xgboost do |opt|
+          opt.set_class Adapters::XGBoostAdapter
+          opt.bind_attribute :model
+          opt.bind_attribute :config
+          opt.bind_attribute :project_name
+          opt.bind_attribute :tune_started_at
+          opt.bind_attribute :y_true
+        end
+
+        dep.when do |_dep|
+          case model
+          when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+            { option: :xgboost }
+          end
+        end
+      end
+
+      def loggers(_study, trial)
+        return unless trial.state.name == "FAIL"
+
+        raise "Trial failed: Stopping optimization."
+      end
+
+      def tune
+        set_defaults!
+
+        @study = Optuna::Study.new
+        @results = []
+        model.task = task
+        x_true, y_true = model.dataset.test(split_ys: true)
+        tune_started_at = EST.now
+        adapter = pick_adapter.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
+                                   x_true: x_true)
+        adapter.configure_callbacks
+
+        @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
+          run_metrics = tune_once(trial, x_true, y_true, adapter)
+
+          result = if model.evaluator.present?
+                     if model.evaluator_metric.present?
+                       run_metrics[model.evaluator_metric]
+                     else
+                       run_metrics[:custom]
+                     end
+                   else
+                     run_metrics[objective.to_sym]
+                   end
+          @results.push(result)
+          result
+        rescue StandardError => e
+          puts "Optuna failed with: #{e.message}"
+        end
+
+        raise "Optuna study failed" unless @study.respond_to?(:best_trial)
+
+        @study.best_trial.params
+      end
+
+      def pick_adapter
+        case model
+        when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+          Adapters::XGBoostAdapter
+        end
+      end
+
+      def tune_once(trial, x_true, y_true, adapter)
+        adapter.run_trial(trial) do |model|
+          y_pred = model.predict(y_true)
+          model.metrics = metrics
+          model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+        end
+      end
+
+      def set_defaults!
+        unless task.present?
+          self.task = model.task
+          raise ArgumentError, "EasyML::Core::Tuner requires task (regression or classification)" unless task.present?
+        end
+        raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless objective.present?
+
+        self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
+      end
+    end
+  end
+end
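End to end, tuning would look roughly like this (a sketch based on the attributes above; `model` stands for a configured EasyML XGBoost model whose dataset responds to `test(split_ys: true)`):

tuner = EasyML::Core::Tuner.new(
  model: model,
  task: "regression",
  objective: "mean_absolute_error",
  n_trials: 25,
  config: { learning_rate: { min: 0.01, max: 0.2 } }
)
best_params = tuner.tune # params of the best Optuna trial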
data/lib/easy_ml/core/uploaders/model_uploader.rb
@@ -0,0 +1,24 @@
+require "carrierwave"
+
+module EasyML
+  module Core
+    module Uploaders
+      class ModelUploader < CarrierWave::Uploader::Base
+        # Choose storage type
+        if Rails.env.production?
+          storage :fog
+        else
+          storage :file
+        end
+
+        def store_dir
+          "easy_ml_models/#{model.name}"
+        end
+
+        def extension_allowlist
+          %w[bin model json]
+        end
+      end
+    end
+  end
+end
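CarrierWave uploaders are conventionally attached with `mount_uploader`; presumably the `EasyML::Model` record elsewhere in this diff does something along these lines (a sketch, not the package's actual code):

class EasyML::Model < ActiveRecord::Base
  # #file then exposes the stored artifact; store_dir above namespaces it by model name
  mount_uploader :file, EasyML::Core::Uploaders::ModelUploader
end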