easy_ml 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/easy_ml/core/model.rb +1 -0
- data/lib/easy_ml/core/model_core.rb +38 -38
- data/lib/easy_ml/core/models/xgboost_core.rb +1 -1
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -1
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +19 -11
- data/lib/easy_ml/support/file_rotate.rb +1 -1
- data/lib/easy_ml/version.rb +1 -1
- data/lib/easy_ml.rb +0 -1
- metadata +46 -5
- data/lib/easy_ml/trainer.rb +0 -92
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b8bf32b84b8e3991a5c9f2f25355dda8e8b685a6016c659f0ec169e509054e80
|
4
|
+
data.tar.gz: 27141e14e2b7aab481decb0e2a0189973ae18d86635bd9a586a9083fdb22312d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: eb876892ad129afe383aa74acfbae05d75ecd98e6564369ac55b3ab82d6b9703cb3c5f6d7764315092dd13cc644df95fbba6ab21390f667273fec2aa2eb56cad
|
7
|
+
data.tar.gz: 866195fa9983ffce3ff08eaf7832df758aeccc0d8f6faef1fc93faa2c06fe6d78a9d4a4742370b50e7710e48b86d08d0b7af98fd16d14812d39b92cc95e8e355
|
data/lib/easy_ml/core/model.rb
CHANGED
data/lib/easy_ml/core/model_core.rb
CHANGED
@@ -21,30 +21,6 @@ module EasyML
|
|
21
21
|
end
|
22
22
|
end
|
23
23
|
|
24
|
-
def dataset_is_a_dataset?
|
25
|
-
return if dataset.nil?
|
26
|
-
return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
|
27
|
-
|
28
|
-
errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
|
29
|
-
end
|
30
|
-
|
31
|
-
def validate_any_metrics?
|
32
|
-
return if metrics.any?
|
33
|
-
|
34
|
-
errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
|
35
|
-
end
|
36
|
-
|
37
|
-
def validate_metrics_for_task
|
38
|
-
nonsensical_metrics = metrics.select do |metric|
|
39
|
-
allowed_metrics.exclude?(metric)
|
40
|
-
end
|
41
|
-
|
42
|
-
return unless nonsensical_metrics.any?
|
43
|
-
|
44
|
-
errors.add(:metrics,
|
45
|
-
"cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
|
46
|
-
end
|
47
|
-
|
48
24
|
def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
|
49
25
|
if x_train.nil?
|
50
26
|
dataset.refresh!
|
@@ -55,16 +31,6 @@ module EasyML
|
|
55
31
|
@is_fit = true
|
56
32
|
end
|
57
33
|
|
58
|
-
def decode_labels(ys, col: nil)
|
59
|
-
dataset.decode_labels(ys, col: col)
|
60
|
-
end
|
61
|
-
|
62
|
-
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
|
63
|
-
evaluator ||= self.evaluator
|
64
|
-
EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
|
65
|
-
evaluator: evaluator)
|
66
|
-
end
|
67
|
-
|
68
34
|
def predict(xs)
|
69
35
|
raise NotImplementedError, "Subclasses must implement predict method"
|
70
36
|
end
|
@@ -82,6 +48,16 @@ module EasyML
|
|
82
48
|
save_model_file
|
83
49
|
end
|
84
50
|
|
51
|
+
def decode_labels(ys, col: nil)
|
52
|
+
dataset.decode_labels(ys, col: col)
|
53
|
+
end
|
54
|
+
|
55
|
+
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
|
56
|
+
evaluator ||= self.evaluator
|
57
|
+
EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
|
58
|
+
evaluator: evaluator)
|
59
|
+
end
|
60
|
+
|
85
61
|
def save_model_file
|
86
62
|
raise "No trained model! Need to train model before saving (call model.fit)" unless fit?
|
87
63
|
|
@@ -116,13 +92,13 @@ module EasyML
|
|
116
92
|
end
|
117
93
|
|
118
94
|
def cleanup!
|
119
|
-
[
|
95
|
+
[carrierwave_dir, model_dir].each do |dir|
|
120
96
|
EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
|
121
97
|
end
|
122
98
|
end
|
123
99
|
|
124
100
|
def cleanup
|
125
|
-
[
|
101
|
+
[carrierwave_dir, model_dir].each do |dir|
|
126
102
|
EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
|
127
103
|
end
|
128
104
|
end
|
@@ -133,7 +109,7 @@ module EasyML
|
|
133
109
|
|
134
110
|
private
|
135
111
|
|
136
|
-
def
|
112
|
+
def carrierwave_dir
|
137
113
|
return unless file.path.present?
|
138
114
|
|
139
115
|
File.dirname(file.path).split("/")[0..-2].join("/")
|
@@ -172,10 +148,34 @@ module EasyML
|
|
172
148
|
end
|
173
149
|
|
174
150
|
def files_to_keep
|
175
|
-
Dir.glob(File.join(
|
151
|
+
Dir.glob(File.join(carrierwave_dir, "**/*")).select { |f| File.file?(f) }.sort_by do |filename|
|
176
152
|
Time.parse(filename.split("/").last.gsub(/\D/, ""))
|
177
153
|
end.reverse.take(5)
|
178
154
|
end
|
155
|
+
|
156
|
+
def dataset_is_a_dataset?
|
157
|
+
return if dataset.nil?
|
158
|
+
return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
|
159
|
+
|
160
|
+
errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
|
161
|
+
end
|
162
|
+
|
163
|
+
def validate_any_metrics?
|
164
|
+
return if metrics.any?
|
165
|
+
|
166
|
+
errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
|
167
|
+
end
|
168
|
+
|
169
|
+
def validate_metrics_for_task
|
170
|
+
nonsensical_metrics = metrics.select do |metric|
|
171
|
+
allowed_metrics.exclude?(metric)
|
172
|
+
end
|
173
|
+
|
174
|
+
return unless nonsensical_metrics.any?
|
175
|
+
|
176
|
+
errors.add(:metrics,
|
177
|
+
"cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
|
178
|
+
end
|
179
179
|
end
|
180
180
|
end
|
181
181
|
end
|
data/lib/easy_ml/core/models/xgboost_core.rb
CHANGED
@@ -155,7 +155,7 @@ module EasyML
|
|
155
155
|
ys = ys.nil? ? nil : _preprocess(ys).flatten
|
156
156
|
kwargs = { label: ys }.compact
|
157
157
|
::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
|
158
|
-
dmat.
|
158
|
+
dmat.feature_names = column_names
|
159
159
|
end
|
160
160
|
end
|
161
161
|
|
data/lib/easy_ml/railtie/generators/migration/migration_generator.rb
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
# lib/railtie/generators/migration/migration_generator.rb
|
2
1
|
require "rails/generators"
|
3
2
|
require "rails/generators/active_record/migration"
|
4
3
|
|
@@ -13,12 +12,7 @@ module EasyML
|
|
13
12
|
source_root File.expand_path("../../templates/migration", __dir__)
|
14
13
|
|
15
14
|
# Define the migration name
|
16
|
-
desc "Generates
|
17
|
-
|
18
|
-
# Define the migration name; can be customized if needed
|
19
|
-
def self.migration_name
|
20
|
-
"create_easy_ml_models"
|
21
|
-
end
|
15
|
+
desc "Generates migrations for EasyMLModel, Dataset, and TunerRun"
|
22
16
|
|
23
17
|
# Specify the next migration number
|
24
18
|
def self.next_migration_number(dirname)
|
@@ -31,10 +25,24 @@ module EasyML
|
|
31
25
|
end
|
32
26
|
end
|
33
27
|
|
34
|
-
# Generate the migration
|
35
|
-
def
|
36
|
-
|
37
|
-
|
28
|
+
# Generate the migration files using the templates
|
29
|
+
def create_migration_files
|
30
|
+
create_easy_ml_models_migration
|
31
|
+
end
|
32
|
+
|
33
|
+
private
|
34
|
+
|
35
|
+
# Generate the migration file for EasyMLModel using the template
|
36
|
+
def create_easy_ml_models_migration
|
37
|
+
migration_template(
|
38
|
+
"create_easy_ml_models.rb.tt",
|
39
|
+
"db/migrate/create_easy_ml_models.rb"
|
40
|
+
)
|
41
|
+
end
|
42
|
+
|
43
|
+
# Get the next migration number
|
44
|
+
def next_migration_number
|
45
|
+
self.class.next_migration_number(Rails.root.join("db/migrate"))
|
38
46
|
end
|
39
47
|
end
|
40
48
|
end
|
data/lib/easy_ml/support/file_rotate.rb
CHANGED
@@ -16,7 +16,7 @@ module EasyML
|
|
16
16
|
files_to_check.each do |file|
|
17
17
|
FileUtils.chown_R(`whoami`.chomp, "staff", file)
|
18
18
|
FileUtils.chmod_R(0o777, file)
|
19
|
-
File.delete(file)
|
19
|
+
File.delete(file) if @files_to_keep.exclude?(file) && File.exist?(file)
|
20
20
|
end
|
21
21
|
end
|
22
22
|
end
|
data/lib/easy_ml/version.rb
CHANGED
data/lib/easy_ml.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: easy_ml
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.1
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Brett Shollenberger
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-10-
|
11
|
+
date: 2024-10-18 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: activerecord
|
@@ -58,20 +58,34 @@ dependencies:
|
|
58
58
|
- - "<"
|
59
59
|
- !ruby/object:Gem::Version
|
60
60
|
version: '4'
|
61
|
+
- !ruby/object:Gem::Dependency
|
62
|
+
name: fog
|
63
|
+
requirement: !ruby/object:Gem::Requirement
|
64
|
+
requirements:
|
65
|
+
- - "~>"
|
66
|
+
- !ruby/object:Gem::Version
|
67
|
+
version: '1.42'
|
68
|
+
type: :runtime
|
69
|
+
prerelease: false
|
70
|
+
version_requirements: !ruby/object:Gem::Requirement
|
71
|
+
requirements:
|
72
|
+
- - "~>"
|
73
|
+
- !ruby/object:Gem::Version
|
74
|
+
version: '1.42'
|
61
75
|
- !ruby/object:Gem::Dependency
|
62
76
|
name: fog-aws
|
63
77
|
requirement: !ruby/object:Gem::Requirement
|
64
78
|
requirements:
|
65
79
|
- - "~>"
|
66
80
|
- !ruby/object:Gem::Version
|
67
|
-
version: '
|
81
|
+
version: '2.0'
|
68
82
|
type: :runtime
|
69
83
|
prerelease: false
|
70
84
|
version_requirements: !ruby/object:Gem::Requirement
|
71
85
|
requirements:
|
72
86
|
- - "~>"
|
73
87
|
- !ruby/object:Gem::Version
|
74
|
-
version: '
|
88
|
+
version: '2.0'
|
75
89
|
- !ruby/object:Gem::Dependency
|
76
90
|
name: glue_gun_dsl
|
77
91
|
requirement: !ruby/object:Gem::Requirement
|
@@ -156,6 +170,34 @@ dependencies:
|
|
156
170
|
- - ">="
|
157
171
|
- !ruby/object:Gem::Version
|
158
172
|
version: '0'
|
173
|
+
- !ruby/object:Gem::Dependency
|
174
|
+
name: wandb
|
175
|
+
requirement: !ruby/object:Gem::Requirement
|
176
|
+
requirements:
|
177
|
+
- - "~>"
|
178
|
+
- !ruby/object:Gem::Version
|
179
|
+
version: 0.1.6
|
180
|
+
type: :runtime
|
181
|
+
prerelease: false
|
182
|
+
version_requirements: !ruby/object:Gem::Requirement
|
183
|
+
requirements:
|
184
|
+
- - "~>"
|
185
|
+
- !ruby/object:Gem::Version
|
186
|
+
version: 0.1.6
|
187
|
+
- !ruby/object:Gem::Dependency
|
188
|
+
name: xgb
|
189
|
+
requirement: !ruby/object:Gem::Requirement
|
190
|
+
requirements:
|
191
|
+
- - ">="
|
192
|
+
- !ruby/object:Gem::Version
|
193
|
+
version: '0'
|
194
|
+
type: :runtime
|
195
|
+
prerelease: false
|
196
|
+
version_requirements: !ruby/object:Gem::Requirement
|
197
|
+
requirements:
|
198
|
+
- - ">="
|
199
|
+
- !ruby/object:Gem::Version
|
200
|
+
version: '0'
|
159
201
|
- !ruby/object:Gem::Dependency
|
160
202
|
name: annotate
|
161
203
|
requirement: !ruby/object:Gem::Requirement
|
@@ -321,7 +363,6 @@ files:
|
|
321
363
|
- lib/easy_ml/support/git_ignorable.rb
|
322
364
|
- lib/easy_ml/support/synced_directory.rb
|
323
365
|
- lib/easy_ml/support/utc.rb
|
324
|
-
- lib/easy_ml/trainer.rb
|
325
366
|
- lib/easy_ml/transforms.rb
|
326
367
|
- lib/easy_ml/version.rb
|
327
368
|
homepage: https://github.com/brettshollenberger/easy_ml
|
data/lib/easy_ml/trainer.rb
DELETED
@@ -1,92 +0,0 @@
|
|
1
|
-
module EasyML
|
2
|
-
class Trainer
|
3
|
-
# include GlueGun::DSL
|
4
|
-
# include EasyML::Logging
|
5
|
-
|
6
|
-
# define_attr :verbose, default: false
|
7
|
-
# define_attr :root_dir do |root_dir|
|
8
|
-
# File.join(root_dir, "trainer")
|
9
|
-
# end
|
10
|
-
|
11
|
-
# define_config :dataset do |config|
|
12
|
-
# config.define_option :default do |option|
|
13
|
-
# option.set_class EasyML::Data::Dataset
|
14
|
-
# option.define_attr :root_dir
|
15
|
-
# option.define_attr :target
|
16
|
-
# option.define_attr :batch_size
|
17
|
-
# end
|
18
|
-
# end
|
19
|
-
|
20
|
-
# define_config :model do |config|
|
21
|
-
# config.define_option :default do |option|
|
22
|
-
# option.set_class EasyML::Model
|
23
|
-
# option.define_attr :root_dir
|
24
|
-
# option.define_attr :name
|
25
|
-
# option.define_attr :hyperparameters
|
26
|
-
# end
|
27
|
-
# end
|
28
|
-
|
29
|
-
# def train
|
30
|
-
# log_info("Starting training process") if verbose
|
31
|
-
|
32
|
-
# dataset.refresh!
|
33
|
-
|
34
|
-
# log_info("Fitting model") if verbose
|
35
|
-
# dataset.train(split_ys: true) do |xs, ys|
|
36
|
-
# model.fit(xs, ys)
|
37
|
-
# end
|
38
|
-
|
39
|
-
# log_info("Saving model") if verbose
|
40
|
-
# model.save
|
41
|
-
|
42
|
-
# log_info("Training completed") if verbose
|
43
|
-
# end
|
44
|
-
|
45
|
-
# def evaluate
|
46
|
-
# log_info("Starting evaluation process") if verbose
|
47
|
-
|
48
|
-
# results = {}
|
49
|
-
|
50
|
-
# %i[train test valid].each do |split|
|
51
|
-
# log_info("Evaluating on #{split} set") if verbose
|
52
|
-
# predictions = []
|
53
|
-
# actuals = []
|
54
|
-
|
55
|
-
# dataset.send(split, split_ys: true) do |xs, ys|
|
56
|
-
# batch_predictions = model.predict(xs)
|
57
|
-
# predictions.concat(batch_predictions.to_a)
|
58
|
-
# actuals.concat(ys.to_a)
|
59
|
-
# end
|
60
|
-
|
61
|
-
# results[split] = calculate_metrics(predictions, actuals)
|
62
|
-
# end
|
63
|
-
|
64
|
-
# log_info("Evaluation completed") if verbose
|
65
|
-
# results
|
66
|
-
# end
|
67
|
-
|
68
|
-
# private
|
69
|
-
|
70
|
-
# def calculate_metrics(predictions, actuals)
|
71
|
-
# # Implement your metric calculations here
|
72
|
-
# # This is a placeholder and should be replaced with actual metric calculations
|
73
|
-
# {
|
74
|
-
# mse: mean_squared_error(predictions, actuals),
|
75
|
-
# mae: mean_absolute_error(predictions, actuals),
|
76
|
-
# r2: r_squared(predictions, actuals)
|
77
|
-
# }
|
78
|
-
# end
|
79
|
-
|
80
|
-
# def mean_squared_error(predictions, actuals)
|
81
|
-
# # Implement MSE calculation
|
82
|
-
# end
|
83
|
-
|
84
|
-
# def mean_absolute_error(predictions, actuals)
|
85
|
-
# # Implement MAE calculation
|
86
|
-
# end
|
87
|
-
|
88
|
-
# def r_squared(predictions, actuals)
|
89
|
-
# # Implement R-squared calculation
|
90
|
-
# end
|
91
|
-
end
|
92
|
-
end
|