easy_ml 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/README.md +270 -0
  3. data/Rakefile +12 -0
  4. data/app/models/easy_ml/model.rb +59 -0
  5. data/app/models/easy_ml/models/xgboost.rb +9 -0
  6. data/app/models/easy_ml/models.rb +5 -0
  7. data/lib/easy_ml/core/model.rb +29 -0
  8. data/lib/easy_ml/core/model_core.rb +181 -0
  9. data/lib/easy_ml/core/model_evaluator.rb +137 -0
  10. data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
  11. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
  12. data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
  13. data/lib/easy_ml/core/models/xgboost.rb +10 -0
  14. data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
  15. data/lib/easy_ml/core/models.rb +10 -0
  16. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
  17. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
  18. data/lib/easy_ml/core/tuner/adapters.rb +10 -0
  19. data/lib/easy_ml/core/tuner.rb +105 -0
  20. data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
  21. data/lib/easy_ml/core/uploaders.rb +7 -0
  22. data/lib/easy_ml/core.rb +9 -0
  23. data/lib/easy_ml/core_ext/pathname.rb +9 -0
  24. data/lib/easy_ml/core_ext.rb +5 -0
  25. data/lib/easy_ml/data/dataloader.rb +6 -0
  26. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
  27. data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
  28. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
  29. data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
  30. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
  31. data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
  32. data/lib/easy_ml/data/dataset/splits.rb +11 -0
  33. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
  34. data/lib/easy_ml/data/dataset/splitters.rb +9 -0
  35. data/lib/easy_ml/data/dataset.rb +430 -0
  36. data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
  37. data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
  38. data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
  39. data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
  40. data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
  41. data/lib/easy_ml/data/datasource.rb +33 -0
  42. data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
  43. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
  44. data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
  45. data/lib/easy_ml/data/preprocessor.rb +238 -0
  46. data/lib/easy_ml/data/utils.rb +50 -0
  47. data/lib/easy_ml/data.rb +8 -0
  48. data/lib/easy_ml/deployment.rb +5 -0
  49. data/lib/easy_ml/engine.rb +26 -0
  50. data/lib/easy_ml/initializers/inflections.rb +4 -0
  51. data/lib/easy_ml/logging.rb +38 -0
  52. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
  53. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
  54. data/lib/easy_ml/support/age.rb +27 -0
  55. data/lib/easy_ml/support/est.rb +1 -0
  56. data/lib/easy_ml/support/file_rotate.rb +23 -0
  57. data/lib/easy_ml/support/git_ignorable.rb +66 -0
  58. data/lib/easy_ml/support/synced_directory.rb +134 -0
  59. data/lib/easy_ml/support/utc.rb +1 -0
  60. data/lib/easy_ml/support.rb +10 -0
  61. data/lib/easy_ml/trainer.rb +92 -0
  62. data/lib/easy_ml/transforms.rb +29 -0
  63. data/lib/easy_ml/version.rb +5 -0
  64. data/lib/easy_ml.rb +23 -0
  65. metadata +353 -0
@@ -0,0 +1,137 @@
+ module EasyML
+   module Core
+     class ModelEvaluator
+       require "numo/narray"
+
+       EVALUATORS = {
+         mean_absolute_error: lambda { |y_pred, y_true|
+           (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
+         },
+         mean_squared_error: lambda { |y_pred, y_true|
+           ((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean
+         },
+         root_mean_squared_error: lambda { |y_pred, y_true|
+           Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean)
+         },
+         r2_score: lambda { |y_pred, y_true|
+           # Convert inputs to Numo::DFloat for numerical operations
+           y_true = Numo::DFloat.cast(y_true)
+           y_pred = Numo::DFloat.cast(y_pred)
+
+           # Calculate the mean of the true values
+           mean_y = y_true.mean
+
+           # Calculate Total Sum of Squares (SS_tot)
+           ss_tot = ((y_true - mean_y)**2).sum
+
+           # Calculate Residual Sum of Squares (SS_res)
+           ss_res = ((y_true - y_pred)**2).sum
+
+           # Handle the edge case where SS_tot is zero
+           if ss_tot.zero?
+             if ss_res.zero?
+               # Perfect prediction when both SS_tot and SS_res are zero
+               1.0
+             else
+               # Undefined R² when SS_tot is zero but SS_res is not
+               Float::NAN
+             end
+           else
+             # Calculate R²
+             1 - (ss_res / ss_tot)
+           end
+         },
+         accuracy_score: lambda { |y_pred, y_true|
+           y_pred = Numo::Int32.cast(y_pred)
+           y_true = Numo::Int32.cast(y_true)
+           y_pred.eq(y_true).count_true.to_f / y_pred.size
+         },
+         precision_score: lambda { |y_pred, y_true|
+           y_pred = Numo::Int32.cast(y_pred)
+           y_true = Numo::Int32.cast(y_true)
+           true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+           predicted_positives = y_pred.eq(1).count_true
+           return 0 if predicted_positives == 0
+
+           true_positives.to_f / predicted_positives
+         },
+         recall_score: lambda { |y_pred, y_true|
+           y_pred = Numo::Int32.cast(y_pred)
+           y_true = Numo::Int32.cast(y_true)
+           true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+           actual_positives = y_true.eq(1).count_true
+           return 0 if actual_positives == 0 # guard against division by zero, mirroring precision_score
+
+           true_positives.to_f / actual_positives
+         },
+         f1_score: lambda { |y_pred, y_true|
+           precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
+           recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
+           return 0 unless (precision + recall) > 0
+
+           2 * (precision * recall) / (precision + recall)
+         }
+       }
+
+       class << self
+         def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
+           y_pred = normalize_input(y_pred)
+           y_true = normalize_input(y_true)
+           check_size(y_pred, y_true)
+
+           metrics_results = {}
+
+           model.metrics.each do |metric|
+             if metric.is_a?(Module) || metric.is_a?(Class)
+               unless metric.respond_to?(:evaluate)
+                 raise "Metric #{metric} must respond to #evaluate in order to be used as a custom evaluator"
+               end
+
+               metrics_results[metric.name] = metric.evaluate(y_pred, y_true)
+             elsif EVALUATORS.key?(metric.to_sym)
+               metrics_results[metric.to_sym] =
+                 EVALUATORS[metric.to_sym].call(y_pred, y_true)
+             end
+           end
+
+           if evaluator.present?
+             if evaluator.is_a?(Class)
+               response = evaluator.new.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+             elsif evaluator.respond_to?(:evaluate)
+               response = evaluator.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+             elsif evaluator.respond_to?(:call)
+               response = evaluator.call(y_pred: y_pred, y_true: y_true, x_true: x_true)
+             else
+               raise "Don't know how to use CustomEvaluator. Must be a class that responds to #evaluate, or a lambda"
+             end
+
+             if response.is_a?(Hash)
+               metrics_results.merge!(response)
+             else
+               metrics_results[:custom] = response
+             end
+           end
+
+           metrics_results
+         end
+
+         def check_size(y_pred, y_true)
+           raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
+         end
+
+         def normalize_input(input)
+           case input
+           when Polars::DataFrame
+             if input.columns.count > 1
+               raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
+             end
+
+             normalize_input(input[input.columns.first])
+           when Polars::Series, Array
+             Numo::DFloat.cast(input)
+           else
+             raise ArgumentError, "Don't know how to evaluate model with y_pred type #{input.class}"
+           end
+         end
+       end
+     end
+   end
+ end
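As a quick orientation to the evaluator above, here is a minimal usage sketch. The Struct-based model stand-in and the sample arrays are hypothetical; in the gem, model is an EasyML model whose #metrics returns the configured metric names:

  require "easy_ml"

  # Hypothetical stand-in exposing only #metrics
  model = Struct.new(:metrics).new(%w[mean_absolute_error r2_score])

  EasyML::Core::ModelEvaluator.evaluate(
    model: model,
    y_pred: [2.5, 5.5, 7.0],
    y_true: [3.0, 5.0, 7.0]
  )
  # => { mean_absolute_error: 0.333..., r2_score: 0.9375 }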
@@ -0,0 +1,34 @@
+ module EasyML
+   module Models
+     module Hyperparameters
+       class Base
+         include GlueGun::DSL
+
+         attribute :learning_rate, :float, default: 0.01
+         attribute :max_iterations, :integer, default: 100
+         attribute :batch_size, :integer, default: 32
+         attribute :regularization, :float, default: 0.0001
+
+         def to_h
+           attributes
+         end
+
+         def merge(other)
+           return self if other.nil?
+
+           # Accept either another hyperparameters object or a plain hash.
+           # (Base, not the Hyperparameters module, is the shared ancestor of
+           # hyperparameter objects, so test against Base here.)
+           other_hash = other.is_a?(Base) ? other.to_h : other
+           merged_hash = to_h.merge(other_hash)
+           self.class.new(**merged_hash)
+         end
+
+         def [](key)
+           send(key) if respond_to?(key)
+         end
+
+         def []=(key, value)
+           send("#{key}=", value) if respond_to?("#{key}=")
+         end
+       end
+     end
+   end
+ end
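A short sketch of how these hyperparameter objects are meant to be used (illustrative values; assumes GlueGun attributes round-trip cleanly through #to_h and keyword construction):

  params = EasyML::Models::Hyperparameters::Base.new(learning_rate: 0.05)
  params[:learning_rate]             # => 0.05, attribute read via []
  params[:batch_size] = 64           # attribute write via []=
  merged = params.merge(learning_rate: 0.1)
  merged[:learning_rate]             # => 0.1; unspecified attributes keep their current values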
@@ -0,0 +1,19 @@
+ module EasyML
+   module Models
+     module Hyperparameters
+       class XGBoost < Base
+         include GlueGun::DSL
+
+         attribute :learning_rate, :float, default: 0.1
+         attribute :max_depth, :integer, default: 6
+         attribute :n_estimators, :integer, default: 100
+         attribute :booster, :string, default: "gbtree"
+         attribute :objective, :string, default: "reg:squarederror"
+
+         validates :objective,
+                   inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
+                                       reg:logistic] }
+       end
+     end
+   end
+ end
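The XGBoost subclass overrides the Base defaults and constrains the objective; a hedged sketch (exactly how GlueGun surfaces the ActiveModel-style inclusion validation is an assumption here):

  hp = EasyML::Models::Hyperparameters::XGBoost.new
  hp.objective                       # => "reg:squarederror"
  hp.objective = "rank:pairwise"     # not in the inclusion list, so validation will flag it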
@@ -0,0 +1,8 @@
+ module EasyML
+   module Models
+     module Hyperparameters
+       require_relative "hyperparameters/base"
+       require_relative "hyperparameters/xgboost"
+     end
+   end
+ end
@@ -0,0 +1,10 @@
+ require_relative "xgboost_core"
+ module EasyML
+   module Core
+     module Models
+       class XGBoost < EasyML::Core::Model
+         include XGBoostCore
+       end
+     end
+   end
+ end
@@ -0,0 +1,220 @@
+ require "wandb"
+ module EasyML
+   module Core
+     module Models
+       module XGBoostCore
+         OBJECTIVES = {
+           classification: {
+             binary: %w[binary:logistic binary:hinge],
+             multi_class: %w[multi:softmax multi:softprob]
+           },
+           regression: %w[reg:squarederror reg:logistic]
+         }
+
+         def self.included(base)
+           base.class_eval do
+             attribute :evaluator
+
+             dependency :callbacks, { array: true } do |dep|
+               dep.option :wandb do |opt|
+                 opt.set_class Wandb::XGBoostCallback
+                 opt.bind_attribute :log_model, default: false
+                 opt.bind_attribute :log_feature_importance, default: true
+                 opt.bind_attribute :importance_type, default: "gain"
+                 opt.bind_attribute :define_metric, default: true
+                 opt.bind_attribute :project_name
+               end
+             end
+
+             dependency :hyperparameters do |dep|
+               dep.set_class EasyML::Models::Hyperparameters::XGBoost
+               dep.bind_attribute :batch_size, default: 32
+               # 0.1 matches the Hyperparameters::XGBoost default; 1.1 is far
+               # outside the usual (0, 1] learning-rate range
+               dep.bind_attribute :learning_rate, default: 0.1
+               dep.bind_attribute :max_depth, default: 6
+               dep.bind_attribute :n_estimators, default: 100
+               dep.bind_attribute :booster, default: "gbtree"
+               dep.bind_attribute :objective, default: "reg:squarederror"
+             end
+           end
+         end
+
+         attr_accessor :model, :booster
+
+         def predict(xs)
+           raise "No trained model! Train a model before calling predict" unless @booster.present?
+           raise "Cannot predict on nil — XGBoost" if xs.nil?
+
+           y_pred = @booster.predict(preprocess(xs))
+
+           case task.to_sym
+           when :classification
+             to_classification(y_pred)
+           else
+             y_pred
+           end
+         end
+
+         def predict_proba(data)
+           # Use the resolved DMatrix class; a bare DMatrix is not in scope here
+           dmat = d_matrix_class.new(data)
+           y_pred = @booster.predict(dmat)
+
+           if y_pred.first.is_a?(Array)
+             # multiple classes
+             y_pred
+           else
+             y_pred.map { |v| [1 - v, v] }
+           end
+         end
+
+         def load(path = nil)
+           path ||= file
+           path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
+
+           raise "No existing model at #{path}" unless File.exist?(path)
+
+           initialize_model do
+             booster_class.new(params: hyperparameters.to_h, model_file: path)
+           end
+         end
+
+         def _save_model_file(path)
+           puts "XGBoost received path #{path}"
+           @booster.save_model(path)
+         end
+
+         def feature_importances
+           @model.booster.feature_names.zip(@model.feature_importances).to_h
+         end
+
+         def base_model
+           ::XGBoost
+         end
+
+         def customize_callbacks
+           yield callbacks
+         end
+
+         private
+
+         def booster_class
+           ::XGBoost::Booster
+         end
+
+         def d_matrix_class
+           ::XGBoost::DMatrix
+         end
+
+         def model_class
+           ::XGBoost::Model
+         end
+
+         def train
+           validate_objective
+
+           # Read the attributes before shadowing them with locals;
+           # a bare `xs = xs.to_a` would evaluate the right-hand side
+           # against a freshly declared nil local.
+           xs = self.xs.to_a.map(&:values)
+           ys = self.ys.to_a.map(&:values)
+           dtrain = d_matrix_class.new(xs, label: ys)
+           @model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
+         end
+
+         def train_in_batches
+           validate_objective
+
+           # Initialize the model with the first batch
+           @model = nil
+           @booster = nil
+           x_valid, y_valid = dataset.valid(split_ys: true)
+           x_train, y_train = dataset.train(split_ys: true)
+           fit_batch(x_train, y_train, x_valid, y_valid)
+         end
+
+         def _preprocess(df)
+           df.to_a.map do |row|
+             row.values.map do |value|
+               case value
+               when Time
+                 value.to_i # Convert Time to Unix timestamp
+               when Date
+                 value.to_time.to_i # Convert Date to Unix timestamp
+               when String
+                 value
+               when TrueClass, FalseClass
+                 value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
+               when Integer
+                 value
+               else
+                 value.to_f # Ensure everything else is converted to a float
+               end
+             end
+           end
+         end
+
+         def preprocess(xs, ys = nil)
+           column_names = xs.columns
+           xs = _preprocess(xs)
+           ys = ys.nil? ? nil : _preprocess(ys).flatten
+           kwargs = { label: ys }.compact
+           ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
+             dmat.instance_variable_set(:@feature_names, column_names)
+           end
+         end
+
+         def initialize_model
+           @model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
+           @booster = yield
+           @model.instance_variable_set(:@booster, @booster)
+         end
+
+         def validate_objective
+           objective = hyperparameters.objective
+           unless task.present?
+             raise ArgumentError,
+                   "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
+           end
+
+           case task.to_sym
+           when :classification
+             _, ys = dataset.data(split_ys: true)
+             classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
+             allowed_objectives = OBJECTIVES[:classification][classification_type]
+           else
+             allowed_objectives = OBJECTIVES[task.to_sym]
+           end
+           return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
+
+           raise ArgumentError,
+                 "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
+         end
+
+         def fit_batch(x_train, y_train, x_valid, y_valid)
+           d_train = preprocess(x_train, y_train)
+           d_valid = preprocess(x_valid, y_valid)
+
+           evals = [[d_train, "train"], [d_valid, "eval"]]
+
+           # If this is the first batch, create the booster
+           if @booster.nil?
+             initialize_model do
+               # use the symbol key, consistent with initialize_model above
+               base_model.train(@hyperparameters.to_h, d_train,
+                                num_boost_round: @hyperparameters.to_h.dig(:n_estimators), evals: evals, callbacks: callbacks)
+             end
+           else
+             # Update the existing booster with the new batch
+             @model.update(d_train)
+           end
+         end
+
+         def to_classification(y_pred)
+           if y_pred.first.is_a?(Array)
+             # multiple classes: take the index of the highest score
+             y_pred.map do |v|
+               v.map.with_index.max_by { |v2, _| v2 }.last
+             end
+           else
+             y_pred.map { |v| v > 0.5 ? 1 : 0 }
+           end
+         end
+       end
+     end
+   end
+ end
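To make the prediction path above concrete, here is the post-processing to_classification applies, re-implemented standalone for illustration (plain Ruby, no gem calls):

  # Binary task: probabilities above 0.5 become class 1
  [0.2, 0.7, 0.5].map { |v| v > 0.5 ? 1 : 0 }
  # => [0, 1, 0]

  # Multi-class task: pick the index of the highest per-class score
  [[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]].map do |scores|
    scores.map.with_index.max_by { |score, _i| score }.last
  end
  # => [1, 0]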
@@ -0,0 +1,10 @@
+ module EasyML
+   module Core
+     module Models
+       require_relative "models/hyperparameters"
+       require_relative "models/xgboost"
+
+       AVAILABLE_MODELS = [XGBoost]
+     end
+   end
+ end
@@ -0,0 +1,63 @@
+ module EasyML
+   module Core
+     class Tuner
+       module Adapters
+         class BaseAdapter
+           include GlueGun::DSL
+
+           def defaults
+             {}
+           end
+
+           attribute :model
+           attribute :config, :hash
+           attribute :project_name, :string
+           attribute :tune_started_at
+           attribute :y_true
+           attribute :x_true
+
+           def run_trial(trial)
+             config = deep_merge_defaults(self.config.clone)
+             suggest_parameters(trial, config)
+             model.fit
+             yield model
+           end
+
+           def configure_callbacks
+             raise "Subclasses of Tuner::Adapter::BaseAdapter must define #configure_callbacks"
+           end
+
+           def suggest_parameters(trial, config)
+             defaults.keys.each do |param_name|
+               param_value = suggest_parameter(trial, param_name, config)
+               model.hyperparameters.send("#{param_name}=", param_value)
+             end
+           end
+
+           def deep_merge_defaults(config)
+             defaults.deep_merge(config) do |_key, default_value, config_value|
+               if default_value.is_a?(Hash) && config_value.is_a?(Hash)
+                 default_value.merge(config_value)
+               else
+                 config_value
+               end
+             end
+           end
+
+           def suggest_parameter(trial, param_name, config)
+             param_config = config[param_name]
+             min = param_config[:min]
+             max = param_config[:max]
+             log = param_config[:log]
+
+             if log
+               trial.suggest_loguniform(param_name.to_s, min, max)
+             else
+               trial.suggest_uniform(param_name.to_s, min, max)
+             end
+           end
+         end
+       end
+     end
+   end
+ end
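A sketch of how run_trial consumes a trial through these helpers (StubTrial is a hypothetical stand-in for an Optuna trial, and model is assumed to be a configured EasyML XGBoost model):

  # Hypothetical trial that always suggests the lower bound
  class StubTrial
    def suggest_loguniform(_name, min, _max)
      min
    end

    def suggest_uniform(_name, min, _max)
      min
    end
  end

  config = { learning_rate: { min: 0.01, max: 0.1, log: true } }
  adapter = EasyML::Core::Tuner::Adapters::XGBoostAdapter.new(model: model, config: config)
  merged = adapter.deep_merge_defaults(config)  # fills in n_estimators and max_depth from defaults
  adapter.suggest_parameters(StubTrial.new, merged)
  # model.hyperparameters now has learning_rate 0.01, n_estimators 100, max_depth 2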
@@ -0,0 +1,50 @@
+ require_relative "base_adapter"
+
+ module EasyML
+   module Core
+     class Tuner
+       module Adapters
+         class XGBoostAdapter < BaseAdapter
+           include GlueGun::DSL
+
+           def defaults
+             {
+               learning_rate: {
+                 min: 0.001,
+                 max: 0.1,
+                 log: true
+               },
+               n_estimators: {
+                 min: 100,
+                 max: 1_000
+               },
+               max_depth: {
+                 min: 2,
+                 max: 20
+               }
+             }
+           end
+
+           def configure_callbacks
+             model.customize_callbacks do |callbacks|
+               return unless callbacks.present?
+
+               wandb_callback = callbacks.detect { |cb| cb.is_a?(Wandb::XGBoostCallback) }
+               return unless wandb_callback.present?
+
+               wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
+               wandb_callback.custom_loggers = [
+                 lambda do |booster, _epoch, _hist|
+                   dtrain = model.send(:preprocess, x_true, y_true)
+                   y_pred = booster.predict(dtrain)
+                   metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+                   Wandb.log(metrics)
+                 end
+               ]
+             end
+           end
+         end
+       end
+     end
+   end
+ end
@@ -0,0 +1,10 @@
+ module EasyML
+   module Core
+     class Tuner
+       module Adapters
+         require_relative "adapters/base_adapter"
+         require_relative "adapters/xgboost_adapter"
+       end
+     end
+   end
+ end
@@ -0,0 +1,105 @@
+ require "optuna"
+ require_relative "tuner/adapters"
+
+ module EasyML
+   module Core
+     class Tuner
+       include GlueGun::DSL
+
+       attribute :model
+       attribute :dataset
+       attribute :project_name, :string
+       attribute :task, :string
+       attribute :config, :hash, default: {}
+       attribute :metrics, :array
+       attribute :objective, :string
+       attribute :n_trials, default: 100
+       attribute :callbacks, :array
+       attr_accessor :study, :results
+
+       dependency :adapter, lazy: false do |dep|
+         dep.option :xgboost do |opt|
+           opt.set_class Adapters::XGBoostAdapter
+           opt.bind_attribute :model
+           opt.bind_attribute :config
+           opt.bind_attribute :project_name
+           opt.bind_attribute :tune_started_at
+           opt.bind_attribute :y_true
+         end
+
+         dep.when do |_dep|
+           case model
+           when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+             { option: :xgboost }
+           end
+         end
+       end
+
+       def loggers(_study, trial)
+         return unless trial.state.name == "FAIL"
+
+         raise "Trial failed: Stopping optimization."
+       end
+
+       def tune
+         set_defaults!
+
+         @study = Optuna::Study.new
+         @results = []
+         model.task = task
+         x_true, y_true = model.dataset.test(split_ys: true)
+         tune_started_at = EST.now
+         adapter = pick_adapter.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
+                                    x_true: x_true)
+         adapter.configure_callbacks
+
+         @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
+           run_metrics = tune_once(trial, x_true, y_true, adapter)
+
+           result = if model.evaluator.present?
+                      if model.evaluator_metric.present?
+                        run_metrics[model.evaluator_metric]
+                      else
+                        run_metrics[:custom]
+                      end
+                    else
+                      run_metrics[objective.to_sym]
+                    end
+           @results.push(result)
+           result
+         rescue StandardError => e
+           puts "Optuna failed with: #{e.message}"
+         end
+
+         raise "Optuna study failed" unless @study.respond_to?(:best_trial)
+
+         @study.best_trial.params
+       end
+
+       def pick_adapter
+         case model
+         when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+           Adapters::XGBoostAdapter
+         end
+       end
+
+       def tune_once(trial, x_true, y_true, adapter)
+         adapter.run_trial(trial) do |model|
+           y_pred = model.predict(x_true) # predict on the features, not the labels
+           model.metrics = metrics
+           model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+         end
+       end
+
+       def set_defaults!
+         unless task.present?
+           self.task = model.task
+           raise ArgumentError, "EasyML::Core::Tuner requires task (regression or classification)" unless task.present?
+         end
+         raise ArgumentError, "Objective required for EasyML::Core::Tuner" unless objective.present?
+
+         self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
+       end
+     end
+   end
+ end
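Putting the tuner together, a minimal end-to-end sketch (model is assumed to be a configured EasyML XGBoost model with a dataset attached; the objective name must match one of the computed metrics):

  tuner = EasyML::Core::Tuner.new(
    model: model,
    task: "regression",
    objective: "mean_absolute_error",  # the per-trial value the study optimizes
    n_trials: 25,
    config: { learning_rate: { min: 0.001, max: 0.1, log: true } }
  )
  best_params = tuner.tune  # runs the Optuna study and returns the best trial's params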
@@ -0,0 +1,24 @@
+ require "carrierwave"
+
+ module EasyML
+   module Core
+     module Uploaders
+       class ModelUploader < CarrierWave::Uploader::Base
+         # Choose storage type
+         if Rails.env.production?
+           storage :fog
+         else
+           storage :file
+         end
+
+         def store_dir
+           "easy_ml_models/#{model.name}"
+         end
+
+         def extension_allowlist
+           %w[bin model json]
+         end
+       end
+     end
+   end
+ end