easy_ml 0.1.1

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/README.md +270 -0
  3. data/Rakefile +12 -0
  4. data/app/models/easy_ml/model.rb +59 -0
  5. data/app/models/easy_ml/models/xgboost.rb +9 -0
  6. data/app/models/easy_ml/models.rb +5 -0
  7. data/lib/easy_ml/core/model.rb +29 -0
  8. data/lib/easy_ml/core/model_core.rb +181 -0
  9. data/lib/easy_ml/core/model_evaluator.rb +137 -0
  10. data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
  11. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
  12. data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
  13. data/lib/easy_ml/core/models/xgboost.rb +10 -0
  14. data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
  15. data/lib/easy_ml/core/models.rb +10 -0
  16. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
  17. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
  18. data/lib/easy_ml/core/tuner/adapters.rb +10 -0
  19. data/lib/easy_ml/core/tuner.rb +105 -0
  20. data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
  21. data/lib/easy_ml/core/uploaders.rb +7 -0
  22. data/lib/easy_ml/core.rb +9 -0
  23. data/lib/easy_ml/core_ext/pathname.rb +9 -0
  24. data/lib/easy_ml/core_ext.rb +5 -0
  25. data/lib/easy_ml/data/dataloader.rb +6 -0
  26. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
  27. data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
  28. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
  29. data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
  30. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
  31. data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
  32. data/lib/easy_ml/data/dataset/splits.rb +11 -0
  33. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
  34. data/lib/easy_ml/data/dataset/splitters.rb +9 -0
  35. data/lib/easy_ml/data/dataset.rb +430 -0
  36. data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
  37. data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
  38. data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
  39. data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
  40. data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
  41. data/lib/easy_ml/data/datasource.rb +33 -0
  42. data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
  43. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
  44. data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
  45. data/lib/easy_ml/data/preprocessor.rb +238 -0
  46. data/lib/easy_ml/data/utils.rb +50 -0
  47. data/lib/easy_ml/data.rb +8 -0
  48. data/lib/easy_ml/deployment.rb +5 -0
  49. data/lib/easy_ml/engine.rb +26 -0
  50. data/lib/easy_ml/initializers/inflections.rb +4 -0
  51. data/lib/easy_ml/logging.rb +38 -0
  52. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
  53. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
  54. data/lib/easy_ml/support/age.rb +27 -0
  55. data/lib/easy_ml/support/est.rb +1 -0
  56. data/lib/easy_ml/support/file_rotate.rb +23 -0
  57. data/lib/easy_ml/support/git_ignorable.rb +66 -0
  58. data/lib/easy_ml/support/synced_directory.rb +134 -0
  59. data/lib/easy_ml/support/utc.rb +1 -0
  60. data/lib/easy_ml/support.rb +10 -0
  61. data/lib/easy_ml/trainer.rb +92 -0
  62. data/lib/easy_ml/transforms.rb +29 -0
  63. data/lib/easy_ml/version.rb +5 -0
  64. data/lib/easy_ml.rb +23 -0
  65. metadata +353 -0
data/lib/easy_ml/core/model_evaluator.rb
@@ -0,0 +1,137 @@
+module EasyML
+  module Core
+    class ModelEvaluator
+      require "numo/narray"
+
+      EVALUATORS = {
+        mean_absolute_error: lambda { |y_pred, y_true|
+          (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
+        },
+        mean_squared_error: lambda { |y_pred, y_true|
+          ((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean
+        },
+        root_mean_squared_error: lambda { |y_pred, y_true|
+          Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean)
+        },
+        r2_score: lambda { |y_pred, y_true|
+          # Convert inputs to Numo::DFloat for numerical operations
+          y_true = Numo::DFloat.cast(y_true)
+          y_pred = Numo::DFloat.cast(y_pred)
+
+          # Calculate the mean of the true values
+          mean_y = y_true.mean
+
+          # Calculate Total Sum of Squares (SS_tot)
+          ss_tot = ((y_true - mean_y)**2).sum
+
+          # Calculate Residual Sum of Squares (SS_res)
+          ss_res = ((y_true - y_pred)**2).sum
+
+          # Handle the edge case where SS_tot is zero
+          if ss_tot.zero?
+            if ss_res.zero?
+              # Perfect prediction when both SS_tot and SS_res are zero
+              1.0
+            else
+              # Undefined R² when SS_tot is zero but SS_res is not
+              Float::NAN
+            end
+          else
+            # Calculate R²
+            1 - (ss_res / ss_tot)
+          end
+        },
+        accuracy_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          y_pred.eq(y_true).count_true.to_f / y_pred.size
+        },
+        precision_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+          predicted_positives = y_pred.eq(1).count_true
+          return 0 if predicted_positives == 0
+
+          true_positives.to_f / predicted_positives
+        },
+        recall_score: lambda { |y_pred, y_true|
+          y_pred = Numo::Int32.cast(y_pred)
+          y_true = Numo::Int32.cast(y_true)
+          true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
+          actual_positives = y_true.eq(1).count_true
+          return 0 if actual_positives == 0
+
+          true_positives.to_f / actual_positives
+        },
+        f1_score: lambda { |y_pred, y_true|
+          precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
+          recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
+          return 0 unless (precision + recall) > 0
+
+          2 * (precision * recall) / (precision + recall)
+        }
+      }
+
+      class << self
+        def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
+          y_pred = normalize_input(y_pred)
+          y_true = normalize_input(y_true)
+          check_size(y_pred, y_true)
+
+          metrics_results = {}
+
+          model.metrics.each do |metric|
+            if metric.is_a?(Module) || metric.is_a?(Class)
+              unless metric.respond_to?(:evaluate)
+                raise "Metric #{metric} must respond to #evaluate in order to be used as a custom evaluator"
+              end
+
+              metrics_results[metric.name] = metric.evaluate(y_pred, y_true)
+            elsif EVALUATORS.key?(metric.to_sym)
+              metrics_results[metric.to_sym] =
+                EVALUATORS[metric.to_sym].call(y_pred, y_true)
+            end
+          end
+
+          if evaluator.present?
+            if evaluator.is_a?(Class)
+              response = evaluator.new.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            elsif evaluator.respond_to?(:evaluate)
+              response = evaluator.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            elsif evaluator.respond_to?(:call)
+              response = evaluator.call(y_pred: y_pred, y_true: y_true, x_true: x_true)
+            else
+              raise "Custom evaluator must be a class or object that responds to #evaluate, or a callable"
+            end
+
+            if response.is_a?(Hash)
+              metrics_results.merge!(response)
+            else
+              metrics_results[:custom] = response
+            end
+          end
+
+          metrics_results
+        end
+
+        def check_size(y_pred, y_true)
+          raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
+        end
+
+        def normalize_input(input)
+          case input
+          when Polars::DataFrame
+            if input.columns.count > 1
+              raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
+            end
+
+            normalize_input(input[input.columns.first])
+          when Polars::Series, Array
+            Numo::DFloat.cast(input)
+          else
+            raise ArgumentError, "Don't know how to evaluate model with y_pred type #{input.class}"
+          end
+        end
+      end
+    end
+  end
+end
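
For orientation, the built-in metrics can be exercised directly through the EVALUATORS lookup table; a minimal sketch with made-up numbers (assuming the gem and numo-narray are loaded):

    y_true = [3.0, 2.5, 4.0, 1.0]
    y_pred = [2.8, 2.7, 3.9, 1.2]

    evals = EasyML::Core::ModelEvaluator::EVALUATORS
    evals[:mean_absolute_error].call(y_pred, y_true)      # => 0.175
    evals[:root_mean_squared_error].call(y_pred, y_true)  # => ~0.18

Each lambda takes (y_pred, y_true) in that order, which only matters for asymmetric metrics such as precision and recall.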
data/lib/easy_ml/core/models/hyperparameters/base.rb
@@ -0,0 +1,34 @@
+module EasyML
+  module Models
+    module Hyperparameters
+      class Base
+        include GlueGun::DSL
+
+        attribute :learning_rate, :float, default: 0.01
+        attribute :max_iterations, :integer, default: 100
+        attribute :batch_size, :integer, default: 32
+        attribute :regularization, :float, default: 0.0001
+
+        def to_h
+          attributes
+        end
+
+        def merge(other)
+          return self if other.nil?
+
+          other_hash = other.is_a?(Base) ? other.to_h : other
+          merged_hash = to_h.merge(other_hash)
+          self.class.new(**merged_hash)
+        end
+
+        def [](key)
+          send(key) if respond_to?(key)
+        end
+
+        def []=(key, value)
+          send("#{key}=", value) if respond_to?("#{key}=")
+        end
+      end
+    end
+  end
+end
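
The intended semantics: #merge returns a new instance rather than mutating the receiver, while [] and []= delegate to the attribute accessors. A sketch (assuming GlueGun::DSL exposes ActiveModel-style attributes keyed by symbol):

    base = EasyML::Models::Hyperparameters::Base.new
    base[:learning_rate]                      # => 0.01 (the default)

    merged = base.merge(learning_rate: 0.05)
    merged.learning_rate                      # => 0.05
    base.learning_rate                        # => 0.01, the original is untouched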
data/lib/easy_ml/core/models/hyperparameters/xgboost.rb
@@ -0,0 +1,19 @@
+module EasyML
+  module Models
+    module Hyperparameters
+      class XGBoost < Base
+        include GlueGun::DSL
+
+        attribute :learning_rate, :float, default: 0.1
+        attribute :max_depth, :integer, default: 6
+        attribute :n_estimators, :integer, default: 100
+        attribute :booster, :string, default: "gbtree"
+        attribute :objective, :string, default: "reg:squarederror"
+
+        validates :objective,
+                  inclusion: { in: %w[binary:logistic binary:hinge multi:softmax multi:softprob reg:squarederror
+                                      reg:logistic] }
+      end
+    end
+  end
+end
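
Because of the inclusion validation, an unsupported objective fails fast rather than surfacing later as a cryptic XGBoost error. A sketch (assuming GlueGun::DSL wires up standard ActiveModel validations):

    params = EasyML::Models::Hyperparameters::XGBoost.new
    params.objective = "multi:softmax"
    params.valid?    # => true

    params.objective = "rank:pairwise"
    params.valid?    # => false, not in the allow-list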
data/lib/easy_ml/core/models/hyperparameters.rb
@@ -0,0 +1,8 @@
+module EasyML
+  module Models
+    module Hyperparameters
+      require_relative "hyperparameters/base"
+      require_relative "hyperparameters/xgboost"
+    end
+  end
+end
data/lib/easy_ml/core/models/xgboost.rb
@@ -0,0 +1,10 @@
+require_relative "xgboost_core"
+module EasyML
+  module Core
+    module Models
+      class XGBoost < EasyML::Core::Model
+        include XGBoostCore
+      end
+    end
+  end
+end
data/lib/easy_ml/core/models/xgboost_core.rb
@@ -0,0 +1,220 @@
+require "wandb"
+module EasyML
+  module Core
+    module Models
+      module XGBoostCore
+        OBJECTIVES = {
+          classification: {
+            binary: %w[binary:logistic binary:hinge],
+            multi_class: %w[multi:softmax multi:softprob]
+          },
+          regression: %w[reg:squarederror reg:logistic]
+        }
+
+        def self.included(base)
+          base.class_eval do
+            attribute :evaluator
+
+            dependency :callbacks, { array: true } do |dep|
+              dep.option :wandb do |opt|
+                opt.set_class Wandb::XGBoostCallback
+                opt.bind_attribute :log_model, default: false
+                opt.bind_attribute :log_feature_importance, default: true
+                opt.bind_attribute :importance_type, default: "gain"
+                opt.bind_attribute :define_metric, default: true
+                opt.bind_attribute :project_name
+              end
+            end
+
+            dependency :hyperparameters do |dep|
+              dep.set_class EasyML::Models::Hyperparameters::XGBoost
+              dep.bind_attribute :batch_size, default: 32
+              dep.bind_attribute :learning_rate, default: 0.1
+              dep.bind_attribute :max_depth, default: 6
+              dep.bind_attribute :n_estimators, default: 100
+              dep.bind_attribute :booster, default: "gbtree"
+              dep.bind_attribute :objective, default: "reg:squarederror"
+            end
+          end
+        end
+
+        attr_accessor :model, :booster
+
+        def predict(xs)
+          raise "No trained model! Train a model before calling predict" unless @booster.present?
+          raise "Cannot predict on nil — XGBoost" if xs.nil?
+
+          y_pred = @booster.predict(preprocess(xs))
+
+          case task.to_sym
+          when :classification
+            to_classification(y_pred)
+          else
+            y_pred
+          end
+        end
+
+        def predict_proba(data)
+          dmat = d_matrix_class.new(data)
+          y_pred = @booster.predict(dmat)
+
+          if y_pred.first.is_a?(Array)
+            # multiple classes
+            y_pred
+          else
+            y_pred.map { |v| [1 - v, v] }
+          end
+        end
+
+        def load(path = nil)
+          path ||= file
+          path = path&.file&.file if path.class.ancestors.include?(CarrierWave::Uploader::Base)
+
+          raise "No existing model at #{path}" unless File.exist?(path)
+
+          initialize_model do
+            booster_class.new(params: hyperparameters.to_h, model_file: path)
+          end
+        end
+
+        def _save_model_file(path)
+          puts "XGBoost received path #{path}"
+          @booster.save_model(path)
+        end
+
+        def feature_importances
+          @model.booster.feature_names.zip(@model.feature_importances).to_h
+        end
+
+        def base_model
+          ::XGBoost
+        end
+
+        def customize_callbacks
+          yield callbacks
+        end
+
+        private
+
+        def booster_class
+          ::XGBoost::Booster
+        end
+
+        def d_matrix_class
+          ::XGBoost::DMatrix
+        end
+
+        def model_class
+          ::XGBoost::Model
+        end
+
+        def train
+          validate_objective
+
+          # Pull the full training split from the dataset
+          xs, ys = dataset.train(split_ys: true)
+          xs = xs.to_a.map(&:values)
+          ys = ys.to_a.map(&:values)
+          dtrain = d_matrix_class.new(xs, label: ys)
+          @model = base_model.train(hyperparameters.to_h, dtrain, callbacks: callbacks)
+        end
+
+        def train_in_batches
+          validate_objective
+
+          # Initialize the model with the first batch
+          @model = nil
+          @booster = nil
+          x_valid, y_valid = dataset.valid(split_ys: true)
+          x_train, y_train = dataset.train(split_ys: true)
+          fit_batch(x_train, y_train, x_valid, y_valid)
+        end
+
+        def _preprocess(df)
+          df.to_a.map do |row|
+            row.values.map do |value|
+              case value
+              when Time
+                value.to_i # Convert Time to Unix timestamp
+              when Date
+                value.to_time.to_i # Convert Date to Unix timestamp
+              when String
+                value
+              when TrueClass, FalseClass
+                value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
+              when Integer
+                value
+              else
+                value.to_f # Ensure everything else is converted to a float
+              end
+            end
+          end
+        end
+
+        def preprocess(xs, ys = nil)
+          column_names = xs.columns
+          xs = _preprocess(xs)
+          ys = ys.nil? ? nil : _preprocess(ys).flatten
+          kwargs = { label: ys }.compact
+          ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
+            dmat.instance_variable_set(:@feature_names, column_names)
+          end
+        end
+
+        def initialize_model
+          @model = model_class.new(n_estimators: hyperparameters.to_h.dig(:n_estimators))
+          @booster = yield
+          @model.instance_variable_set(:@booster, @booster)
+        end
+
+        def validate_objective
+          objective = hyperparameters.objective
+          unless task.present?
+            raise ArgumentError,
+                  "cannot train model without task. Please specify either regression or classification (model.task = :regression)"
+          end
+
+          case task.to_sym
+          when :classification
+            _, ys = dataset.data(split_ys: true)
+            classification_type = ys[ys.columns.first].uniq.count <= 2 ? :binary : :multi_class
+            allowed_objectives = OBJECTIVES[:classification][classification_type]
+          else
+            allowed_objectives = OBJECTIVES[task.to_sym]
+          end
+          return if allowed_objectives.map(&:to_sym).include?(objective.to_sym)
+
+          raise ArgumentError,
+                "cannot use #{objective} for #{task} task. Allowed objectives are: #{allowed_objectives.join(", ")}"
+        end
+
+        def fit_batch(x_train, y_train, x_valid, y_valid)
+          d_train = preprocess(x_train, y_train)
+          d_valid = preprocess(x_valid, y_valid)
+
+          evals = [[d_train, "train"], [d_valid, "eval"]]
+
+          # If this is the first batch, create the booster
+          if @booster.nil?
+            initialize_model do
+              base_model.train(hyperparameters.to_h, d_train,
+                               num_boost_round: hyperparameters.to_h.dig(:n_estimators), evals: evals, callbacks: callbacks)
+            end
+          else
+            # Update the existing booster with the new batch
+            @model.update(d_train)
+          end
+        end
+
+        def to_classification(y_pred)
+          if y_pred.first.is_a?(Array)
+            # multiple classes: pick the index of the highest probability
+            y_pred.map do |v|
+              v.map.with_index.max_by { |v2, _| v2 }.last
+            end
+          else
+            y_pred.map { |v| v > 0.5 ? 1 : 0 }
+          end
+        end
+      end
+    end
+  end
+end
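
Taken together, predict and predict_proba expose label and probability views of the same booster output. A rough usage sketch for a binary classifier (model and dataset construction elided, since they depend on the host app's configuration; x_test is illustrative):

    model.task = :classification
    model.fit
    model.predict(x_test)        # => [0, 1, 1, ...] labels via to_classification (0.5 cutoff)
    model.predict_proba(x_test)  # => [[0.82, 0.18], ...] [1 - p, p] pairs for the binary case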
data/lib/easy_ml/core/models.rb
@@ -0,0 +1,10 @@
+module EasyML
+  module Core
+    module Models
+      require_relative "models/hyperparameters"
+      require_relative "models/xgboost"
+
+      AVAILABLE_MODELS = [XGBoost]
+    end
+  end
+end
data/lib/easy_ml/core/tuner/adapters/base_adapter.rb
@@ -0,0 +1,63 @@
+module EasyML
+  module Core
+    class Tuner
+      module Adapters
+        class BaseAdapter
+          include GlueGun::DSL
+
+          def defaults
+            {}
+          end
+
+          attribute :model
+          attribute :config, :hash
+          attribute :project_name, :string
+          attribute :tune_started_at
+          attribute :y_true
+          attribute :x_true
+
+          def run_trial(trial)
+            config = deep_merge_defaults(self.config.clone)
+            suggest_parameters(trial, config)
+            model.fit
+            yield model
+          end
+
+          def configure_callbacks
+            raise "Subclasses of Tuner::Adapters::BaseAdapter must define #configure_callbacks"
+          end
+
+          def suggest_parameters(trial, config)
+            defaults.keys.each do |param_name|
+              param_value = suggest_parameter(trial, param_name, config)
+              model.hyperparameters.send("#{param_name}=", param_value)
+            end
+          end
+
+          def deep_merge_defaults(config)
+            defaults.deep_merge(config) do |_key, default_value, config_value|
+              if default_value.is_a?(Hash) && config_value.is_a?(Hash)
+                default_value.merge(config_value)
+              else
+                config_value
+              end
+            end
+          end
+
+          def suggest_parameter(trial, param_name, config)
+            param_config = config[param_name]
+            min = param_config[:min]
+            max = param_config[:max]
+            log = param_config[:log]
+
+            if log
+              trial.suggest_loguniform(param_name.to_s, min, max)
+            else
+              trial.suggest_uniform(param_name.to_s, min, max)
+            end
+          end
+        end
+      end
+    end
+  end
+end
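
A concrete adapter only has to supply #defaults (the search space) and #configure_callbacks. A hypothetical LightGBM adapter as a sketch; the class and its parameters are illustrative, not part of the gem:

    class LightGBMAdapter < EasyML::Core::Tuner::Adapters::BaseAdapter
      def defaults
        {
          learning_rate: { min: 0.001, max: 0.3, log: true },
          num_leaves: { min: 16, max: 256 }
        }
      end

      def configure_callbacks
        # nothing to wire up in this sketch
      end
    end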
data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
@@ -0,0 +1,50 @@
+require_relative "base_adapter"
+
+module EasyML
+  module Core
+    class Tuner
+      module Adapters
+        class XGBoostAdapter < BaseAdapter
+          include GlueGun::DSL
+
+          def defaults
+            {
+              learning_rate: {
+                min: 0.001,
+                max: 0.1,
+                log: true
+              },
+              n_estimators: {
+                min: 100,
+                max: 1_000
+              },
+              max_depth: {
+                min: 2,
+                max: 20
+              }
+            }
+          end
+
+          def configure_callbacks
+            model.customize_callbacks do |callbacks|
+              return unless callbacks.present?
+
+              wandb_callback = callbacks.detect { |cb| cb.is_a?(Wandb::XGBoostCallback) }
+              return unless wandb_callback.present?
+
+              wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
+              wandb_callback.custom_loggers = [
+                lambda do |booster, _epoch, _hist|
+                  dtrain = model.send(:preprocess, x_true, y_true)
+                  y_pred = booster.predict(dtrain)
+                  metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+                  Wandb.log(metrics)
+                end
+              ]
+            end
+          end
+        end
+      end
+    end
+  end
+end
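
Note how the search space resolves: deep_merge_defaults merges each parameter's hash key by key, with user-supplied config winning over these defaults. An illustrative example (assuming ActiveSupport's Hash#deep_merge is available, as the gem's Rails integration implies):

    adapter = EasyML::Core::Tuner::Adapters::XGBoostAdapter.new(config: { learning_rate: { max: 0.3 } })
    adapter.deep_merge_defaults(adapter.config)
    # => { learning_rate: { min: 0.001, max: 0.3, log: true },
    #      n_estimators:  { min: 100, max: 1_000 },
    #      max_depth:     { min: 2, max: 20 } }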
data/lib/easy_ml/core/tuner/adapters.rb
@@ -0,0 +1,10 @@
+module EasyML
+  module Core
+    class Tuner
+      module Adapters
+        require_relative "adapters/base_adapter"
+        require_relative "adapters/xgboost_adapter"
+      end
+    end
+  end
+end
data/lib/easy_ml/core/tuner.rb
@@ -0,0 +1,105 @@
+require "optuna"
+require_relative "tuner/adapters"
+
+module EasyML
+  module Core
+    class Tuner
+      include GlueGun::DSL
+
+      attribute :model
+      attribute :dataset
+      attribute :project_name, :string
+      attribute :task, :string
+      attribute :config, :hash, default: {}
+      attribute :metrics, :array
+      attribute :objective, :string
+      attribute :n_trials, default: 100
+      attribute :callbacks, :array
+      attr_accessor :study, :results
+
+      dependency :adapter, lazy: false do |dep|
+        dep.option :xgboost do |opt|
+          opt.set_class Adapters::XGBoostAdapter
+          opt.bind_attribute :model
+          opt.bind_attribute :config
+          opt.bind_attribute :project_name
+          opt.bind_attribute :tune_started_at
+          opt.bind_attribute :y_true
+        end
+
+        dep.when do |_dep|
+          case model
+          when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+            { option: :xgboost }
+          end
+        end
+      end
+
+      def loggers(_study, trial)
+        return unless trial.state.name == "FAIL"
+
+        raise "Trial failed: Stopping optimization."
+      end
+
+      def tune
+        set_defaults!
+
+        @study = Optuna::Study.new
+        @results = []
+        model.task = task
+        x_true, y_true = model.dataset.test(split_ys: true)
+        tune_started_at = EST.now
+        adapter = pick_adapter.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
+                                   x_true: x_true)
+        adapter.configure_callbacks
+
+        @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
+          run_metrics = tune_once(trial, x_true, y_true, adapter)
+
+          result = if model.evaluator.present?
+                     if model.evaluator_metric.present?
+                       run_metrics[model.evaluator_metric]
+                     else
+                       run_metrics[:custom]
+                     end
+                   else
+                     run_metrics[objective.to_sym]
+                   end
+          @results.push(result)
+          result
+        rescue StandardError => e
+          puts "Optuna failed with: #{e.message}"
+        end
+
+        raise "Optuna study failed" unless @study.respond_to?(:best_trial)
+
+        @study.best_trial.params
+      end
+
+      def pick_adapter
+        case model
+        when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
+          Adapters::XGBoostAdapter
+        end
+      end
+
+      def tune_once(trial, x_true, y_true, adapter)
+        adapter.run_trial(trial) do |model|
+          y_pred = model.predict(x_true)
+          model.metrics = metrics
+          model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
+        end
+      end
+
+      def set_defaults!
+        unless task.present?
+          self.task = model.task
+          raise ArgumentError, "EasyML::Core::Tuner requires task (regression or classification)" unless task.present?
+        end
+        raise ArgumentError, "Objective required for EasyML::Core::Tuner" unless objective.present?
+
+        self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
+      end
+    end
+  end
+end
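
End to end, a tuning run looks roughly like the following (model construction elided; values illustrative):

    tuner = EasyML::Core::Tuner.new(
      model: model,
      task: "regression",
      objective: "root_mean_squared_error",
      metrics: %w[root_mean_squared_error r2_score],
      n_trials: 25,
      config: { learning_rate: { min: 0.01, max: 0.3, log: true } }
    )
    best_params = tuner.tune  # params hash of the best Optuna trial

set_defaults! raises up front if task or objective is missing, so misconfigured runs fail before any trials start.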
data/lib/easy_ml/core/uploaders/model_uploader.rb
@@ -0,0 +1,24 @@
+require "carrierwave"
+
+module EasyML
+  module Core
+    module Uploaders
+      class ModelUploader < CarrierWave::Uploader::Base
+        # Choose storage type
+        if Rails.env.production?
+          storage :fog
+        else
+          storage :file
+        end
+
+        def store_dir
+          "easy_ml_models/#{model.name}"
+        end
+
+        def extension_allowlist
+          %w[bin model json]
+        end
+      end
+    end
+  end
+end
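
ModelUploader follows the standard CarrierWave pattern, so the host model presumably mounts it along these lines (a sketch, not necessarily the gem's actual model code):

    class EasyML::Model < ActiveRecord::Base
      mount_uploader :file, EasyML::Core::Uploaders::ModelUploader
    end

    # Stored files then land under "easy_ml_models/<model.name>",
    # restricted to the bin/model/json extensions allow-listed above.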