easy_ml 0.2.0.pre.rc55 → 0.2.0.pre.rc57
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/app/frontend/components/ScheduleModal.tsx +1 -1
- data/app/frontend/components/SearchableSelect.tsx +6 -8
- data/app/jobs/easy_ml/refresh_dataset_job.rb +3 -0
- data/app/models/easy_ml/column.rb +16 -3
- data/app/models/easy_ml/column_list.rb +6 -2
- data/app/models/easy_ml/dataset.rb +4 -5
- data/app/models/easy_ml/event.rb +5 -3
- data/app/models/easy_ml/model.rb +5 -2
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +4 -3
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +1 -1
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +9 -9
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +4 -4
- data/lib/easy_ml/core/model_evaluator.rb +18 -3
- data/lib/easy_ml/core/tuner.rb +22 -16
- data/lib/easy_ml/version.rb +1 -1
- data/public/easy_ml/assets/.vite/manifest.json +1 -1
- data/public/easy_ml/assets/assets/entrypoints/{Application.tsx-Dr3jVR78.js → Application.tsx-DTZ2348z.js} +28 -28
- data/public/easy_ml/assets/assets/entrypoints/{Application.tsx-Dr3jVR78.js.map → Application.tsx-DTZ2348z.js.map} +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: e52412950fefc02e9b838930f132873c726440ebbc343159504d7d3287a39d05
|
4
|
+
data.tar.gz: 44ff18d1f1df78b542c8e536427189fce63d147e7e86623d219ed9b89c501ca7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 1e543781fb426a6fa7fe6ad6f5b7c924bdab38d88ac8ad7288db3a24f683661b3745a6f2176c993899a9f9737af7e54dfa59cc439a71739d3e2d2d2d75714621
|
7
|
+
data.tar.gz: 3f012c5a3126eec7a69c3c11dd45017f7c2ded7a2bfd5e6e70bcaa388000b19e50d19ed15dc6b47786f61b698cc081e915abade7ece544a3c8a14d0a8f5c4696
|
@@ -91,7 +91,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, metrics, t
|
|
91
91
|
day_of_month: initialData.retraining_job?.at?.day_of_month ?? 1
|
92
92
|
},
|
93
93
|
metric: initialData.retraining_job?.metric || (metrics[initialData.task]?.[0]?.value ?? ''),
|
94
|
-
threshold: initialData.retraining_job?.threshold
|
94
|
+
threshold: initialData.retraining_job?.threshold ?? (initialData.task === 'classification' ? 0.85 : 0.1),
|
95
95
|
tuner_config: initialData.retraining_job?.tuner_config ? {
|
96
96
|
n_trials: initialData.retraining_job.tuner_config.n_trials || 10,
|
97
97
|
config: {
|
@@ -60,7 +60,10 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
|
|
60
60
|
}
|
61
61
|
}, [isOpen]);
|
62
62
|
|
63
|
-
const handleOptionClick = (optionValue: Option['value']) => {
|
63
|
+
const handleOptionClick = (optionValue: Option['value'], e: React.MouseEvent) => {
|
64
|
+
debugger;
|
65
|
+
e.preventDefault();
|
66
|
+
e.stopPropagation();
|
64
67
|
onChange(optionValue);
|
65
68
|
setIsOpen(false);
|
66
69
|
setSearchQuery('');
|
@@ -86,7 +89,7 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
|
|
86
89
|
placeholder="Search..."
|
87
90
|
value={searchQuery}
|
88
91
|
onChange={(e) => setSearchQuery(e.target.value)}
|
89
|
-
|
92
|
+
onMouseDown={(e) => e.stopPropagation()}
|
90
93
|
/>
|
91
94
|
</div>
|
92
95
|
</div>
|
@@ -105,11 +108,7 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
|
|
105
108
|
className={`w-full text-left px-4 py-2 hover:bg-gray-100 ${
|
106
109
|
option.value === value ? 'bg-blue-50' : ''
|
107
110
|
}`}
|
108
|
-
onMouseDown={(e) =>
|
109
|
-
e.preventDefault();
|
110
|
-
e.stopPropagation();
|
111
|
-
handleOptionClick(option.value);
|
112
|
-
}}
|
111
|
+
onMouseDown={(e) => handleOptionClick(option.value, e)}
|
113
112
|
>
|
114
113
|
<div className="flex items-center justify-between">
|
115
114
|
<span className="block font-medium">
|
@@ -140,7 +139,6 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
|
|
140
139
|
type="button"
|
141
140
|
onMouseDown={(e) => {
|
142
141
|
e.preventDefault();
|
143
|
-
e.stopPropagation();
|
144
142
|
setIsOpen(!isOpen);
|
145
143
|
}}
|
146
144
|
className="w-full bg-white relative border border-gray-300 rounded-md shadow-sm pl-3 pr-10 py-2 text-left cursor-pointer focus:outline-none focus:ring-1 focus:ring-blue-500 focus:border-blue-500"
|
@@ -3,6 +3,8 @@ module EasyML
|
|
3
3
|
def perform(id)
|
4
4
|
begin
|
5
5
|
dataset = EasyML::Dataset.find(id)
|
6
|
+
return if dataset.workflow_status == :analyzing
|
7
|
+
|
6
8
|
puts "Refreshing dataset #{dataset.name}"
|
7
9
|
puts "Needs refresh? #{dataset.needs_refresh?}"
|
8
10
|
unless dataset.needs_refresh?
|
@@ -12,6 +14,7 @@ module EasyML
|
|
12
14
|
create_event(dataset, "started")
|
13
15
|
|
14
16
|
puts "Prepare! #{dataset.name}"
|
17
|
+
dataset.unlock!
|
15
18
|
dataset.prepare
|
16
19
|
if dataset.features.needs_fit.any?
|
17
20
|
dataset.fit_features(async: true)
|
@@ -30,7 +30,6 @@ module EasyML
|
|
30
30
|
validates :name, uniqueness: { scope: :dataset_id }
|
31
31
|
|
32
32
|
before_save :ensure_valid_datatype
|
33
|
-
after_create :set_date_column_if_date_splitter
|
34
33
|
after_save :handle_date_column_change
|
35
34
|
before_save :set_defaults
|
36
35
|
|
@@ -41,6 +40,18 @@ module EasyML
|
|
41
40
|
scope :datetime, -> { where(datatype: "datetime") }
|
42
41
|
scope :date_column, -> { where(is_date_column: true) }
|
43
42
|
|
43
|
+
def columns
|
44
|
+
[name].concat(virtual_columns)
|
45
|
+
end
|
46
|
+
|
47
|
+
def virtual_columns
|
48
|
+
if one_hot?
|
49
|
+
allowed_categories.map { |cat| "#{name}_#{cat}" }
|
50
|
+
else
|
51
|
+
[]
|
52
|
+
end
|
53
|
+
end
|
54
|
+
|
44
55
|
def datatype=(dtype)
|
45
56
|
write_attribute(:datatype, dtype)
|
46
57
|
write_attribute(:polars_datatype, dtype)
|
@@ -88,9 +99,11 @@ module EasyML
|
|
88
99
|
end
|
89
100
|
|
90
101
|
def allowed_categories
|
91
|
-
return
|
102
|
+
return [] unless one_hot?
|
103
|
+
stats = dataset.preprocessor.statistics
|
104
|
+
return [] if stats.nil? || stats.blank?
|
92
105
|
|
93
|
-
|
106
|
+
stats.dup.to_h.dig(name.to_sym, :allowed_categories).sort.concat(["other"])
|
94
107
|
end
|
95
108
|
|
96
109
|
def date_column?
|
@@ -1,6 +1,6 @@
|
|
1
1
|
module EasyML
|
2
2
|
module ColumnList
|
3
|
-
def sync
|
3
|
+
def sync(delete: true)
|
4
4
|
return unless dataset.schema.present?
|
5
5
|
|
6
6
|
EasyML::Column.transaction do
|
@@ -8,7 +8,11 @@ module EasyML
|
|
8
8
|
existing_columns = where(name: col_names)
|
9
9
|
import_new(col_names, existing_columns)
|
10
10
|
update_existing(existing_columns)
|
11
|
-
|
11
|
+
|
12
|
+
if delete
|
13
|
+
delete_missing(existing_columns)
|
14
|
+
end
|
15
|
+
|
12
16
|
if existing_columns.none? # Totally new dataset
|
13
17
|
dataset.after_create_columns
|
14
18
|
end
|
@@ -175,7 +175,6 @@ module EasyML
|
|
175
175
|
|
176
176
|
def actually_refresh
|
177
177
|
refreshing do
|
178
|
-
split_data
|
179
178
|
process_data
|
180
179
|
fully_reload
|
181
180
|
learn
|
@@ -273,10 +272,10 @@ module EasyML
|
|
273
272
|
raw.split_at.present? && raw.split_at < datasource.last_updated_at
|
274
273
|
end
|
275
274
|
|
276
|
-
def learn
|
275
|
+
def learn(delete: true)
|
277
276
|
learn_schema
|
278
277
|
learn_statistics
|
279
|
-
columns.sync
|
278
|
+
columns.sync(delete: delete)
|
280
279
|
end
|
281
280
|
|
282
281
|
def refreshing
|
@@ -399,7 +398,7 @@ module EasyML
|
|
399
398
|
|
400
399
|
# Learn will update columns, so if any features have been added
|
401
400
|
# since the last time columns were learned, we should re-learn the schema
|
402
|
-
learn if idx ==
|
401
|
+
learn(delete: false) if idx == 1 && needs_learn?(df)
|
403
402
|
df = apply_column_mask(df, inference: inference) unless all_columns
|
404
403
|
raise_on_nulls(df) if inference
|
405
404
|
df, = processed.split_features_targets(df, true, target) if split_ys
|
@@ -516,7 +515,7 @@ module EasyML
|
|
516
515
|
end
|
517
516
|
|
518
517
|
def drop_cols
|
519
|
-
@drop_cols ||= preloaded_columns.select(&:hidden).
|
518
|
+
@drop_cols ||= preloaded_columns.select(&:hidden).flat_map(&:columns)
|
520
519
|
end
|
521
520
|
|
522
521
|
def drop_if_null
|
data/app/models/easy_ml/event.rb
CHANGED
@@ -56,14 +56,16 @@ module EasyML
|
|
56
56
|
create_event(model, "failed", error)
|
57
57
|
end
|
58
58
|
|
59
|
+
def self.easy_ml_context(stacktrace)
|
60
|
+
stacktrace.select { |loc| loc.match?(/easy_ml/) }
|
61
|
+
end
|
62
|
+
|
59
63
|
def self.format_stacktrace(error)
|
60
64
|
return nil if error.nil?
|
61
65
|
|
62
66
|
topline = error.inspect
|
63
67
|
|
64
|
-
stacktrace = error.backtrace
|
65
|
-
loc.match?(/easy_ml/)
|
66
|
-
end
|
68
|
+
stacktrace = easy_ml_context(error.backtrace)
|
67
69
|
|
68
70
|
%(#{topline}
|
69
71
|
|
data/app/models/easy_ml/model.rb
CHANGED
@@ -354,15 +354,16 @@ module EasyML
|
|
354
354
|
dataset.decode_labels(ys, col: col)
|
355
355
|
end
|
356
356
|
|
357
|
-
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
|
357
|
+
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil, dataset: nil)
|
358
358
|
evaluator ||= self.evaluator
|
359
359
|
if y_pred.nil?
|
360
360
|
inputs = default_evaluation_inputs
|
361
361
|
y_pred = inputs[:y_pred]
|
362
362
|
y_true = inputs[:y_true]
|
363
363
|
x_true = inputs[:x_true]
|
364
|
+
dataset = inputs[:dataset]
|
364
365
|
end
|
365
|
-
EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true, evaluator: evaluator)
|
366
|
+
EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true, dataset: dataset, evaluator: evaluator)
|
366
367
|
end
|
367
368
|
|
368
369
|
def evaluator
|
@@ -524,11 +525,13 @@ module EasyML
|
|
524
525
|
|
525
526
|
def default_evaluation_inputs
|
526
527
|
x_true, y_true = dataset.test(split_ys: true)
|
528
|
+
ds = dataset.test(all_columns: true)
|
527
529
|
y_pred = predict(x_true)
|
528
530
|
{
|
529
531
|
x_true: x_true,
|
530
532
|
y_true: y_true,
|
531
533
|
y_pred: y_pred,
|
534
|
+
dataset: ds,
|
532
535
|
}
|
533
536
|
end
|
534
537
|
|
@@ -32,7 +32,7 @@ module EasyML
|
|
32
32
|
false
|
33
33
|
end
|
34
34
|
|
35
|
-
def
|
35
|
+
def test_dataset
|
36
36
|
if tuner.present?
|
37
37
|
[tuner.x_true, tuner.y_true]
|
38
38
|
else
|
@@ -46,11 +46,12 @@ module EasyML
|
|
46
46
|
log_frequency = 10
|
47
47
|
if epoch % log_frequency == 0
|
48
48
|
model.adapter.external_model = booster
|
49
|
-
x_true, y_true =
|
49
|
+
x_true, y_true = test_dataset
|
50
50
|
@preprocessed ||= model.preprocess(x_true)
|
51
51
|
y_pred = model.predict(@preprocessed)
|
52
|
+
dataset = model.dataset.test(all_columns: true)
|
52
53
|
|
53
|
-
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
54
|
+
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true, dataset: dataset)
|
54
55
|
Wandb.log(metrics)
|
55
56
|
end
|
56
57
|
|
@@ -32,7 +32,7 @@ module EasyML
|
|
32
32
|
end
|
33
33
|
|
34
34
|
# Instance methods that evaluators must implement
|
35
|
-
def evaluate(y_pred: nil, y_true: nil, x_true: nil)
|
35
|
+
def evaluate(y_pred: nil, y_true: nil, x_true: nil, dataset: nil)
|
36
36
|
raise NotImplementedError, "#{self.class} must implement #evaluate"
|
37
37
|
end
|
38
38
|
|
@@ -5,7 +5,7 @@ module EasyML
|
|
5
5
|
class AccuracyScore
|
6
6
|
include BaseEvaluator
|
7
7
|
|
8
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
8
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
9
9
|
y_pred = Numo::Int32.cast(y_pred)
|
10
10
|
y_true = Numo::Int32.cast(y_true)
|
11
11
|
y_pred.eq(y_true).count_true.to_f / y_pred.size
|
@@ -23,7 +23,7 @@ module EasyML
|
|
23
23
|
class PrecisionScore
|
24
24
|
include BaseEvaluator
|
25
25
|
|
26
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
26
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
27
27
|
y_pred = Numo::Int32.cast(y_pred)
|
28
28
|
y_true = Numo::Int32.cast(y_true)
|
29
29
|
true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
|
@@ -45,7 +45,7 @@ module EasyML
|
|
45
45
|
class RecallScore
|
46
46
|
include BaseEvaluator
|
47
47
|
|
48
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
48
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
49
49
|
y_pred = Numo::Int32.cast(y_pred)
|
50
50
|
y_true = Numo::Int32.cast(y_true)
|
51
51
|
true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
|
@@ -65,9 +65,9 @@ module EasyML
|
|
65
65
|
class F1Score
|
66
66
|
include BaseEvaluator
|
67
67
|
|
68
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
69
|
-
precision = PrecisionScore.new.evaluate(y_pred: y_pred, y_true: y_true)
|
70
|
-
recall = RecallScore.new.evaluate(y_pred: y_pred, y_true: y_true)
|
68
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
69
|
+
precision = PrecisionScore.new.evaluate(y_pred: y_pred, y_true: y_true, dataset: dataset)
|
70
|
+
recall = RecallScore.new.evaluate(y_pred: y_pred, y_true: y_true, dataset: dataset)
|
71
71
|
return 0 unless (precision + recall) > 0
|
72
72
|
|
73
73
|
2 * (precision * recall) / (precision + recall)
|
@@ -85,7 +85,7 @@ module EasyML
|
|
85
85
|
class AUC
|
86
86
|
include BaseEvaluator
|
87
87
|
|
88
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
88
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
89
89
|
y_pred = Numo::DFloat.cast(y_pred)
|
90
90
|
y_true = Numo::Int32.cast(y_true)
|
91
91
|
|
@@ -132,8 +132,8 @@ module EasyML
|
|
132
132
|
class ROC_AUC
|
133
133
|
include BaseEvaluator
|
134
134
|
|
135
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
136
|
-
AUC.new.evaluate(y_pred: y_pred, y_true: y_true)
|
135
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
136
|
+
AUC.new.evaluate(y_pred: y_pred, y_true: y_true, dataset: dataset)
|
137
137
|
end
|
138
138
|
|
139
139
|
def description
|
@@ -5,7 +5,7 @@ module EasyML
|
|
5
5
|
class MeanAbsoluteError
|
6
6
|
include BaseEvaluator
|
7
7
|
|
8
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
8
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
9
9
|
(Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
|
10
10
|
end
|
11
11
|
|
@@ -21,7 +21,7 @@ module EasyML
|
|
21
21
|
class MeanSquaredError
|
22
22
|
include BaseEvaluator
|
23
23
|
|
24
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
24
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
25
25
|
((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)) ** 2).mean
|
26
26
|
end
|
27
27
|
|
@@ -37,7 +37,7 @@ module EasyML
|
|
37
37
|
class RootMeanSquaredError
|
38
38
|
include BaseEvaluator
|
39
39
|
|
40
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
40
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
41
41
|
Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)) ** 2).mean)
|
42
42
|
end
|
43
43
|
|
@@ -61,7 +61,7 @@ module EasyML
|
|
61
61
|
"maximize"
|
62
62
|
end
|
63
63
|
|
64
|
-
def evaluate(y_pred:, y_true:, x_true: nil)
|
64
|
+
def evaluate(y_pred:, y_true:, x_true: nil, dataset: nil)
|
65
65
|
y_true = Numo::DFloat.cast(y_true)
|
66
66
|
y_pred = Numo::DFloat.cast(y_pred)
|
67
67
|
|
@@ -98,13 +98,21 @@ module EasyML
|
|
98
98
|
end
|
99
99
|
end
|
100
100
|
|
101
|
-
def evaluate(model:, y_pred:, y_true:, x_true: nil, evaluator: nil)
|
101
|
+
def evaluate(model:, y_pred:, y_true:, x_true: nil, evaluator: nil, dataset: nil)
|
102
102
|
y_pred = normalize_input(y_pred)
|
103
103
|
y_true = normalize_input(y_true)
|
104
104
|
check_size(y_pred, y_true)
|
105
105
|
|
106
106
|
metrics_results = {}
|
107
107
|
|
108
|
+
if x_true.nil?
|
109
|
+
x_true = model.dataset.test
|
110
|
+
end
|
111
|
+
|
112
|
+
if dataset.nil?
|
113
|
+
dataset = model.dataset.test(all_columns: true)
|
114
|
+
end
|
115
|
+
|
108
116
|
model.metrics.each do |metric|
|
109
117
|
evaluator_class = get(metric.to_sym)
|
110
118
|
next unless evaluator_class
|
@@ -115,6 +123,7 @@ module EasyML
|
|
115
123
|
y_pred: y_pred,
|
116
124
|
y_true: y_true,
|
117
125
|
x_true: x_true,
|
126
|
+
dataset: dataset,
|
118
127
|
)
|
119
128
|
end
|
120
129
|
|
@@ -124,7 +133,7 @@ module EasyML
|
|
124
133
|
raise "Unknown evaluator: #{evaluator}" unless evaluator_class
|
125
134
|
|
126
135
|
evaluator_instance = evaluator_class.new
|
127
|
-
response = evaluator_instance.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
136
|
+
response = evaluator_instance.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true, dataset: dataset)
|
128
137
|
|
129
138
|
if response.is_a?(Hash)
|
130
139
|
metrics_results.merge!(response)
|
@@ -145,6 +154,9 @@ module EasyML
|
|
145
154
|
def normalize_input(input)
|
146
155
|
case input
|
147
156
|
when Array
|
157
|
+
if input.first.class == TrueClass || input.first.class == FalseClass
|
158
|
+
input = input.map { |value| value ? 1.0 : 0.0 }
|
159
|
+
end
|
148
160
|
Numo::DFloat.cast(input)
|
149
161
|
when Polars::DataFrame
|
150
162
|
if input.columns.count > 1
|
@@ -152,7 +164,10 @@ module EasyML
|
|
152
164
|
end
|
153
165
|
|
154
166
|
normalize_input(input[input.columns.first])
|
155
|
-
when Polars::Series
|
167
|
+
when Polars::Series
|
168
|
+
if input.dtype == Polars::Boolean
|
169
|
+
input = input.cast(Polars::Int64)
|
170
|
+
end
|
156
171
|
Numo::DFloat.cast(input)
|
157
172
|
else
|
158
173
|
raise ArgumentError, "Don't know how to evaluate model with y_pred type #{input.class}"
|
data/lib/easy_ml/core/tuner.rb
CHANGED
@@ -8,7 +8,7 @@ module EasyML
|
|
8
8
|
:metrics, :objective, :n_trials, :direction, :evaluator,
|
9
9
|
:study, :results, :adapter, :tune_started_at, :x_true, :y_true,
|
10
10
|
:project_name, :job, :current_run, :trial_enumerator, :progress_block,
|
11
|
-
:tuner_job
|
11
|
+
:tuner_job, :dataset
|
12
12
|
|
13
13
|
def initialize(options = {})
|
14
14
|
@model = options[:model]
|
@@ -77,6 +77,7 @@ module EasyML
|
|
77
77
|
x_true, y_true = model.dataset.test(split_ys: true)
|
78
78
|
self.x_true = x_true
|
79
79
|
self.y_true = y_true
|
80
|
+
self.dataset = model.dataset.test(all_columns: true)
|
80
81
|
adapter.tune_started_at = tune_started_at
|
81
82
|
adapter.y_true = y_true
|
82
83
|
adapter.x_true = x_true
|
@@ -96,14 +97,6 @@ module EasyML
|
|
96
97
|
run_metrics = tune_once
|
97
98
|
result = calculate_result(run_metrics)
|
98
99
|
@results.push(result)
|
99
|
-
|
100
|
-
params = {
|
101
|
-
hyperparameters: model.hyperparameters.to_h,
|
102
|
-
value: result,
|
103
|
-
status: :success,
|
104
|
-
}.compact
|
105
|
-
|
106
|
-
@tuner_run.update!(params)
|
107
100
|
@study.tell(@current_trial, result)
|
108
101
|
rescue StandardError => e
|
109
102
|
@tuner_run.update!(status: :failed, hyperparameters: {})
|
@@ -138,14 +131,27 @@ module EasyML
|
|
138
131
|
)
|
139
132
|
self.current_run = @tuner_run
|
140
133
|
|
141
|
-
adapter.run_trial(@current_trial) do |model|
|
142
|
-
model.
|
143
|
-
|
144
|
-
|
145
|
-
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
146
|
-
puts metrics
|
147
|
-
metrics
|
134
|
+
model = adapter.run_trial(@current_trial) do |model|
|
135
|
+
model.tap do
|
136
|
+
model.fit(tuning: true, &progress_block)
|
137
|
+
end
|
148
138
|
end
|
139
|
+
|
140
|
+
y_pred = model.predict(x_true)
|
141
|
+
model.metrics = metrics
|
142
|
+
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true, dataset: dataset)
|
143
|
+
metric = metrics.symbolize_keys.dig(model.evaluator[:metric].to_sym)
|
144
|
+
|
145
|
+
puts metrics
|
146
|
+
|
147
|
+
params = {
|
148
|
+
hyperparameters: model.hyperparameters.to_h,
|
149
|
+
value: metric,
|
150
|
+
status: :success,
|
151
|
+
}.compact
|
152
|
+
|
153
|
+
@tuner_run.update!(params)
|
154
|
+
metrics
|
149
155
|
end
|
150
156
|
|
151
157
|
private
|
data/lib/easy_ml/version.rb
CHANGED