rumale-model_selection 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +27 -0
- data/README.md +34 -0
- data/lib/rumale/model_selection/cross_validation.rb +107 -0
- data/lib/rumale/model_selection/function.rb +47 -0
- data/lib/rumale/model_selection/grid_search_cv.rb +213 -0
- data/lib/rumale/model_selection/group_k_fold.rb +93 -0
- data/lib/rumale/model_selection/group_shuffle_split.rb +108 -0
- data/lib/rumale/model_selection/k_fold.rb +78 -0
- data/lib/rumale/model_selection/shuffle_split.rb +86 -0
- data/lib/rumale/model_selection/stratified_k_fold.rb +95 -0
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +115 -0
- data/lib/rumale/model_selection/time_series_split.rb +89 -0
- data/lib/rumale/model_selection/version.rb +8 -0
- data/lib/rumale/model_selection.rb +16 -0
- metadata +121 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 28fa6e0e0c5832366f7130aa8185283b9ae316220b4cebec64774acab7f5cfd6
|
4
|
+
data.tar.gz: c85aa131bbcc278f99f2aa91dc6d529c291a2274dbf46c96b468191fefc8015d
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: ad3df526ec16120843536a9f199ba3f0c5856f40ac5a9de545b34b43bb60a798ad17b1c75588e7cfc59caa34bb886c7f5f180aacb9f24e338d70ebd389bd58aa
|
7
|
+
data.tar.gz: 8214f4cec33f056f49d2a479d23f16bc264b92c1e457af695c20d6b5b5233f78c2ab735d6ed584c00ce5333ec9779a8966d6bfd5c96128b7984cda6269b88732
|
data/LICENSE.txt
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
Copyright (c) 2022 Atsushi Tatsuma
|
2
|
+
All rights reserved.
|
3
|
+
|
4
|
+
Redistribution and use in source and binary forms, with or without
|
5
|
+
modification, are permitted provided that the following conditions are met:
|
6
|
+
|
7
|
+
* Redistributions of source code must retain the above copyright notice, this
|
8
|
+
list of conditions and the following disclaimer.
|
9
|
+
|
10
|
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11
|
+
this list of conditions and the following disclaimer in the documentation
|
12
|
+
and/or other materials provided with the distribution.
|
13
|
+
|
14
|
+
* Neither the name of the copyright holder nor the names of its
|
15
|
+
contributors may be used to endorse or promote products derived from
|
16
|
+
this software without specific prior written permission.
|
17
|
+
|
18
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
21
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
22
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
23
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
24
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
25
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
26
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
27
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
data/README.md
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
# Rumale::ModelSelection
|
2
|
+
|
3
|
+
[Gem Version](https://badge.fury.io/rb/rumale-model_selection)
|
4
|
+
[BSD 3-Clause License](https://github.com/yoshoku/rumale/blob/main/rumale-model_selection/LICENSE.txt)
|
5
|
+
[Documentation](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
|
6
|
+
|
7
|
+
Rumale is a machine learning library in Ruby.
|
8
|
+
Rumale::ModelSelection provides model validation techniques,
|
9
|
+
such as k-fold cross-validation, time series cross-validation, and grid search,
|
10
|
+
with Rumale interface.
|
11
|
+
|
12
|
+
## Installation
|
13
|
+
|
14
|
+
Add this line to your application's Gemfile:
|
15
|
+
|
16
|
+
```ruby
|
17
|
+
gem 'rumale-model_selection'
|
18
|
+
```
|
19
|
+
|
20
|
+
And then execute:
|
21
|
+
|
22
|
+
$ bundle install
|
23
|
+
|
24
|
+
Or install it yourself as:
|
25
|
+
|
26
|
+
$ gem install rumale-model_selection
|
27
|
+
|
28
|
+
## Documentation
|
29
|
+
|
30
|
+
- [Rumale API Documentation - ModelSelection](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
|
31
|
+
|
32
|
+
## License
|
33
|
+
|
34
|
+
The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/evaluation_measure/log_loss'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
# This module consists of the classes for model validation techniques.
|
7
|
+
module ModelSelection
|
8
|
+
# CrossValidation is a class that evaluates a given classifier with cross-validation method.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/linear_model'
|
12
|
+
# require 'rumale/model_selection/stratified_k_fold'
|
13
|
+
# require 'rumale/model_selection/cross_validation'
|
14
|
+
#
|
15
|
+
# svc = Rumale::LinearModel::SVC.new
|
16
|
+
# kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
17
|
+
# cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
|
18
|
+
# report = cv.perform(samples, labels)
|
19
|
+
# mean_test_score = report[:test_score].inject(:+) / kf.n_splits
|
20
|
+
#
|
21
|
+
class CrossValidation
|
22
|
+
# Return the classifier of which performance is evaluated.
|
23
|
+
# @return [Classifier]
|
24
|
+
attr_reader :estimator
|
25
|
+
|
26
|
+
# Return the splitter that divides dataset.
|
27
|
+
# @return [Splitter]
|
28
|
+
attr_reader :splitter
|
29
|
+
|
30
|
+
# Return the evaluator that calculates score.
|
31
|
+
# @return [Evaluator]
|
32
|
+
attr_reader :evaluator
|
33
|
+
|
34
|
+
# Return the flag indicating whether to calculate the score of training dataset.
|
35
|
+
# @return [Boolean]
|
36
|
+
attr_reader :return_train_score
|
37
|
+
|
38
|
+
# Create a new evaluator with cross-validation method.
#
# @param estimator [Classifier] The classifier of which performance is evaluated.
# @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
# @param evaluator [Evaluator] The evaluator that calculates score of estimator results.
#   If nil is given, the score method of the estimator is used by #perform.
# @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
def initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false)
  @estimator = estimator
  @splitter = splitter
  @evaluator = evaluator
  @return_train_score = return_train_score
end
|
50
|
+
|
51
|
+
# Perform the evaluation of the given estimator with cross-validation method.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
#   The dataset to be used to evaluate the estimator.
# @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs])
#   The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor.
# @return [Hash] The report summarizing the results of cross-validation.
#   * :fit_time (Array<Float>) The elapsed times, in seconds, of fitting the estimator for each split.
#   * :test_score (Array<Float>) The scores of testing dataset for each split.
#   * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if
#     the return_train_score is false.
def perform(x, y)
  # Initialize the report of cross validation.
  report = { test_score: [], train_score: nil, fit_time: [] }
  report[:train_score] = [] if @return_train_score

  # The scoring rule does not depend on the split, so it is chosen once:
  # the estimator's own score, the evaluator on class probabilities (log-loss),
  # or the evaluator on plain predictions.
  scorer = if @evaluator.nil?
             ->(sx, sy) { @estimator.score(sx, sy) }
           elsif log_loss?
             ->(sx, sy) { @evaluator.score(sy, @estimator.predict_proba(sx)) }
           else
             ->(sx, sy) { @evaluator.score(sy, @estimator.predict(sx)) }
           end

  # Evaluate the estimator on each split.
  @splitter.split(x, y).each do |train_ids, test_ids|
    # For kernel machines the columns are restricted to the training ids
    # (presumably x is then a precomputed kernel matrix -- TODO confirm);
    # otherwise `true` selects every column.
    feature_ids = !kernel_machine? || train_ids
    train_x = x[train_ids, feature_ids]
    train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true]
    test_x = x[test_ids, feature_ids]
    test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true]

    # Fit the estimator, timing it with the monotonic clock so that sub-second
    # durations are kept and wall-clock adjustments cannot distort the result.
    start_time = Process.clock_gettime(Process::CLOCK_MONOTONIC)
    @estimator.fit(train_x, train_y)
    report[:fit_time].push(Process.clock_gettime(Process::CLOCK_MONOTONIC) - start_time)

    # Calculate scores and prepare the report.
    report[:test_score].push(scorer.call(test_x, test_y))
    report[:train_score].push(scorer.call(train_x, train_y)) if @return_train_score
  end

  report
end
|
95
|
+
|
96
|
+
private
|
97
|
+
|
98
|
+
# Tell whether the estimator's class name contains 'Rumale::KernelMachine'.
#
# @return [Boolean]
def kernel_machine?
  # Class#name is nil for anonymous classes; to_s keeps the check from raising.
  @estimator.class.name.to_s.include?('Rumale::KernelMachine')
end
|
101
|
+
|
102
|
+
# Tell whether the evaluator is a Rumale::EvaluationMeasure::LogLoss instance.
#
# @return [Boolean]
def log_loss?
  @evaluator.is_a?(::Rumale::EvaluationMeasure::LogLoss)
end
|
105
|
+
end
|
106
|
+
end
|
107
|
+
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
|
5
|
+
require 'rumale/model_selection/shuffle_split'
|
6
|
+
require 'rumale/model_selection/stratified_shuffle_split'
|
7
|
+
|
8
|
+
module Rumale
|
9
|
+
# This module consists of the classes for model validation techniques.
|
10
|
+
module ModelSelection
|
11
|
+
module_function
|
12
|
+
|
13
|
+
# Split a dataset randomly into train and test data.
#
# @example
#   require 'rumale/model_selection/function'
#
#   x_train, x_test, y_train, y_test = Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to be used to generate data indices.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used to generate data indices for stratified random permutation.
#   If stratify = false, this parameter is not used for generating the indices and may be omitted.
# @param test_size [Float] The ratio of number of samples for test data.
# @param train_size [Float] The ratio of number of samples for train data.
#   If nil is given, it sets to 1 - test_size.
# @param stratify [Boolean] The flag indicating whether to perform stratify split.
# @param random_seed [Integer] The seed value used to initialize the random generator.
# @return [Array<Numo::NArray>] The set of training and testing data, in the order
#   [x_train, x_test, y_train, y_test]. y_train and y_test are nil if y is not given.
# @raise [ArgumentError] If stratify is true and y is not given.
def train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil)
  raise ArgumentError, 'The labels y must be given when stratify is true.' if stratify && y.nil?

  splitter_class = if stratify
                     ::Rumale::ModelSelection::StratifiedShuffleSplit
                   else
                     ::Rumale::ModelSelection::ShuffleSplit
                   end
  splitter = splitter_class.new(
    n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
  )
  train_ids, test_ids = splitter.split(x, y).first

  x_train = x[train_ids, true].dup
  x_test = x[test_ids, true].dup
  # y is optional: without labels there is nothing to slice, so both parts stay nil.
  y_train = y[train_ids].dup unless y.nil?
  y_test = y[test_ids].dup unless y.nil?
  [x_train, x_test, y_train, y_test]
end
|
46
|
+
end
|
47
|
+
end
|
@@ -0,0 +1,213 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/model_selection/cross_validation'
|
5
|
+
|
6
|
+
module Rumale
|
7
|
+
module ModelSelection
|
8
|
+
# GridSearchCV is a class that performs hyperparameter optimization with grid search method.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/ensemble'
|
12
|
+
# require 'rumale/model_selection/stratified_k_fold'
|
13
|
+
# require 'rumale/model_selection/grid_search_cv'
|
14
|
+
#
|
15
|
+
# rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
|
16
|
+
# pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
|
17
|
+
# kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
18
|
+
# gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
|
19
|
+
# gs.fit(samples, labels)
|
20
|
+
# p gs.cv_results
|
21
|
+
# p gs.best_params
|
22
|
+
#
|
23
|
+
# @example
|
24
|
+
# rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
|
25
|
+
# svc = Rumale::LinearModel::SVC.new(random_seed: 1)
|
26
|
+
# pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
|
27
|
+
# pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
|
28
|
+
# kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
29
|
+
# gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
|
30
|
+
# gs.fit(samples, labels)
|
31
|
+
# p gs.cv_results
|
32
|
+
# p gs.best_params
|
33
|
+
#
|
34
|
+
class GridSearchCV < ::Rumale::Base::Estimator
|
35
|
+
# Return the result of cross validation for each parameter.
|
36
|
+
# @return [Hash]
|
37
|
+
attr_reader :cv_results
|
38
|
+
|
39
|
+
# Return the score of the estimator learned with the best parameter.
|
40
|
+
# @return [Float]
|
41
|
+
attr_reader :best_score
|
42
|
+
|
43
|
+
# Return the best parameter set.
|
44
|
+
# @return [Hash]
|
45
|
+
attr_reader :best_params
|
46
|
+
|
47
|
+
# Return the index of the best parameter.
|
48
|
+
# @return [Integer]
|
49
|
+
attr_reader :best_index
|
50
|
+
|
51
|
+
# Return the estimator learned with the best parameter.
|
52
|
+
# @return [Estimator]
|
53
|
+
attr_reader :best_estimator
|
54
|
+
|
55
|
+
# Create a new grid search method.
#
# @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with grid search method.
# @param param_grid [Hash/Array<Hash>] The parameter sets, represented as a hash (or an array of hashes) that
#   consists of parameter names as keys and arrays of parameter values as values.
# @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
# @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
#   If nil is given, the score method of estimator is used for evaluation.
# @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
#   evaluation score is larger.
# @raise [TypeError] If param_grid does not have the expected structure.
def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
  super()
  # The collaborators are deep-copied through Marshal, so later changes to the
  # caller's objects do not affect the search.
  @params = {
    param_grid: valid_param_grid(param_grid),
    estimator: Marshal.load(Marshal.dump(estimator)),
    splitter: Marshal.load(Marshal.dump(splitter)),
    evaluator: Marshal.load(Marshal.dump(evaluator)),
    greater_is_better: greater_is_better
  }
end
|
75
|
+
|
76
|
+
# Fit the model with given training data and all sets of parameters.
#
# Every parameter combination is cross-validated and recorded, the best one is
# selected, and an estimator configured with it is fitted on the whole data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
# @return [GridSearchCV] The learned estimator with grid search.
def fit(x, y)
  init_attrs

  param_combinations.each do |prm_set|
    prm_set.each { |prms| store_cv_result(prms, perform_cross_validation(x, y, prms)) }
  end

  find_best_params

  @best_estimator = configurated_estimator(@best_params)
  @best_estimator.fit(x, y)
  self
end
|
97
|
+
|
98
|
+
# Call the decision_function method of learned estimator with the best parameter.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
# @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
def decision_function(x)
  @best_estimator.decision_function(x)
end
|
105
|
+
|
106
|
+
# Call the predict method of learned estimator with the best parameter.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
# @return [Numo::NArray] Predicted results.
def predict(x)
  @best_estimator.predict(x)
end
|
113
|
+
|
114
|
+
# Call the predict_log_proba method of learned estimator with the best parameter.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
def predict_log_proba(x)
  @best_estimator.predict_log_proba(x)
end
|
121
|
+
|
122
|
+
# Call the predict_proba method of learned estimator with the best parameter.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
  @best_estimator.predict_proba(x)
end
|
129
|
+
|
130
|
+
# Call the score method of learned estimator with the best parameter.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
# @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
# @return [Float] The score of estimator.
def score(x, y)
  @best_estimator.score(x, y)
end
|
138
|
+
|
139
|
+
private
|
140
|
+
|
141
|
+
# Validate the parameter grid and normalize it to an array of hashes.
#
# @param grid [Hash/Array<Hash>] The parameter grid; every hash value must be an Array.
# @return [Array<Hash>] The grid itself if it is an Array, otherwise the Hash wrapped in an Array.
# @raise [TypeError] If grid is neither Hash nor Array, if an element is not a Hash,
#   or if a parameter value is not an Array.
def valid_param_grid(grid)
  raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)

  grids = grid.is_a?(Hash) ? [grid] : grid
  grids.each do |h|
    raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
    raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
  end
  grids
end
|
151
|
+
|
152
|
+
# Expand every hash of the parameter grid into all of its parameter combinations.
# The result is memoized in @param_combinations.
#
# @return [Array<Array<Hash>>] For each grid hash, the list of parameter sets (name => value)
#   obtained from the Cartesian product of its value arrays, with names in sorted order.
def param_combinations
  @param_combinations ||= @params[:param_grid].map do |prm|
    # One axis per parameter: [[name, v1], [name, v2], ...].
    axes = prm.sort.map { |k, v| [k].product(v) }
    # A grid without parameters has exactly one (empty) combination.
    next [{}] if axes.empty?

    axes.first.product(*axes.drop(1)).map(&:to_h)
  end
end
|
158
|
+
|
159
|
+
# Cross-validate an estimator configured with the given parameter set.
# The train scores are always requested, since #store_cv_result records them.
#
# @param x [Numo::DFloat] The training data.
# @param y [Numo::NArray] The target values or labels.
# @param prms [Hash] The parameter set to apply to the estimator.
# @return [Hash] The report returned by CrossValidation#perform.
def perform_cross_validation(x, y, prms)
  cv = ::Rumale::ModelSelection::CrossValidation.new(
    estimator: configurated_estimator(prms),
    splitter: @params[:splitter],
    evaluator: @params[:evaluator],
    return_train_score: true
  )
  cv.perform(x, y)
end
|
165
|
+
|
166
|
+
# Build a fresh copy of the base estimator with the given parameters applied.
#
# For a pipeline, each key has the form '<step>__<parameter>' and is applied to
# the params of the named step; otherwise keys are applied to the estimator's
# own params. The stored base estimator is never modified.
#
# @param prms [Hash] The parameter set (name => value).
# @return [Estimator] A deep copy of the base estimator with the parameters set.
# @raise [ArgumentError] If a pipeline key lacks the '__' separator or names an unknown step.
def configurated_estimator(prms)
  estimator = Marshal.load(Marshal.dump(@params[:estimator]))
  unless pipeline?
    prms.each { |k, v| estimator.params[k] = v }
    return estimator
  end

  prms.each do |k, v|
    # Split on the first '__' only, so parameter names may themselves contain '__'.
    est_name, prm_name = k.to_s.split('__', 2)
    step = prm_name && estimator.steps[est_name.to_sym]
    if step.nil?
      raise ArgumentError,
            "Expect pipeline parameter name to be '<step>__<parameter>' with an existing step: #{k}"
    end

    step.params[prm_name.to_sym] = v
  end
  estimator
end
|
178
|
+
|
179
|
+
# Reset the search results: an empty list for every cv_results column and nil
# for the best score, parameters, index and estimator.
def init_attrs
  @cv_results = %i[mean_test_score std_test_score
                   mean_train_score std_train_score
                   mean_fit_time std_fit_time params].to_h { |v| [v, []] }
  @best_score = nil
  @best_params = nil
  @best_index = nil
  @best_estimator = nil
end
|
188
|
+
|
189
|
+
# Append the summary of one cross-validation report to cv_results.
#
# @param prms [Hash] The parameter set that produced the report.
# @param report [Hash] The report with :test_score, :train_score and :fit_time lists.
def store_cv_result(prms, report)
  # For each measured quantity, record its mean and standard deviation over the splits.
  %i[test_score train_score fit_time].each do |key|
    values = Numo::DFloat[*report[key]]
    @cv_results[:"mean_#{key}"].push(values.mean)
    @cv_results[:"std_#{key}"].push(values.stddev)
  end
  @cv_results[:params].push(prms)
end
|
201
|
+
|
202
|
+
# Select the best parameter set from cv_results by mean test score: the largest
# score if greater_is_better is set, the smallest otherwise. On ties the first
# occurrence wins.
def find_best_params
  scores = @cv_results[:mean_test_score]
  @best_score = @params[:greater_is_better] ? scores.max : scores.min
  @best_index = scores.index(@best_score)
  @best_params = @cv_results[:params][@best_index]
end
|
207
|
+
|
208
|
+
# Tell whether the base estimator's class name contains 'Rumale::Pipeline'.
#
# @return [Boolean]
def pipeline?
  # Class#name is nil for anonymous classes; to_s keeps the check from raising.
  @params[:estimator].class.name.to_s.include?('Rumale::Pipeline')
end
|
211
|
+
end
|
212
|
+
end
|
213
|
+
end
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
require 'rumale/preprocessing/label_encoder'
|
5
|
+
|
6
|
+
module Rumale
|
7
|
+
module ModelSelection
|
8
|
+
# GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
|
9
|
+
# The data points belonging to the same group are not split into different folds.
|
10
|
+
# The number of groups should be greater than or equal to the number of splits.
|
11
|
+
#
|
12
|
+
# @example
|
13
|
+
# require 'rumale/model_selection/group_k_fold'
|
14
|
+
#
|
15
|
+
# cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
|
16
|
+
# x = Numo::DFloat.new(8, 2).rand
|
17
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
18
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
19
|
+
# puts '---'
|
20
|
+
# pp train_ids
|
21
|
+
# pp test_ids
|
22
|
+
# end
|
23
|
+
#
|
24
|
+
# # ---
|
25
|
+
# # [0, 1, 2, 3, 4]
|
26
|
+
# # [5, 6, 7]
|
27
|
+
# # ---
|
28
|
+
# # [3, 4, 5, 6, 7]
|
29
|
+
# # [0, 1, 2]
|
30
|
+
# # ---
|
31
|
+
# # [0, 1, 2, 5, 6, 7]
|
32
|
+
# # [3, 4]
|
33
|
+
#
|
34
|
+
class GroupKFold
|
35
|
+
include ::Rumale::Base::Splitter
|
36
|
+
|
37
|
+
# Return the number of folds.
|
38
|
+
# @return [Integer]
|
39
|
+
attr_reader :n_splits
|
40
|
+
|
41
|
+
# Create a new data splitter for grouped K-fold cross validation.
#
# @param n_splits [Integer] The number of folds.
def initialize(n_splits: 5)
  @n_splits = n_splits
end
|
47
|
+
|
48
|
+
# Generate data indices for grouped K-fold cross validation.
#
# @overload split(x, y, groups) -> Array
#   @param x [Numo::DFloat] (shape: [n_samples, n_features])
#     The dataset to be used to generate data indices for grouped K-fold cross validation.
#   @param y [Numo::Int32] (shape: [n_samples])
#     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
#   @param groups [Numo::Int32] (shape: [n_samples])
#     The group labels to be used to generate data indices for grouped K-fold cross validation.
#   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
# @raise [ArgumentError] If the number of distinct groups is less than the number of splits.
def split(x, _y, groups)
  # Re-encode the group labels; the encoded values are used below as indices
  # into group_to_fold (presumably 0...n_groups -- TODO confirm against LabelEncoder).
  encoder = ::Rumale::Preprocessing::LabelEncoder.new
  groups = encoder.fit_transform(groups)
  n_groups = encoder.classes.size

  if n_groups < @n_splits
    raise ArgumentError,
          'The number of groups should be greater than or equal to the number of splits.'
  end

  # Order the groups by their number of samples, reversed so the largest come first.
  n_samples_per_group = groups.bincount
  group_ids = n_samples_per_group.sort_index.reverse
  n_samples_per_group = n_samples_per_group[group_ids]

  n_samples_per_fold = Numo::Int32.zeros(@n_splits)
  group_to_fold = Numo::Int32.zeros(n_groups)

  # Greedy balancing: each group goes to the fold that currently holds the fewest samples.
  n_samples_per_group.each_with_index do |weight, id|
    min_sample_fold_id = n_samples_per_fold.min_index
    n_samples_per_fold[min_sample_fold_id] += weight
    group_to_fold[group_ids[id]] = min_sample_fold_id
  end

  n_samples = x.shape[0]
  sample_ids = Array(0...n_samples)
  # Fold number of every sample, looked up through its group.
  fold_ids = group_to_fold[groups]

  Array.new(@n_splits) do |fid|
    test_ids = fold_ids.eq(fid).where.to_a
    train_ids = sample_ids - test_ids
    [train_ids, test_ids]
  end
end
|
91
|
+
end
|
92
|
+
end
|
93
|
+
end
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# GroupShuffleSplit is a class that generates the set of data indices
|
8
|
+
# for random permutation cross-validation by randomly selecting group labels.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/model_selection/group_shuffle_split'
|
12
|
+
#
|
13
|
+
# cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
|
14
|
+
# x = Numo::DFloat.new(8, 2).rand
|
15
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
16
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
17
|
+
# puts '---'
|
18
|
+
# pp train_ids
|
19
|
+
# pp test_ids
|
20
|
+
# end
|
21
|
+
#
|
22
|
+
# # ---
|
23
|
+
# # [0, 1, 2, 5, 6, 7]
|
24
|
+
# # [3, 4]
|
25
|
+
# # ---
|
26
|
+
# # [3, 4, 5, 6, 7]
|
27
|
+
# # [0, 1, 2]
|
28
|
+
#
|
29
|
+
class GroupShuffleSplit
|
30
|
+
include Rumale::Base::Splitter
|
31
|
+
|
32
|
+
# Return the number of folds.
|
33
|
+
# @return [Integer]
|
34
|
+
attr_reader :n_splits
|
35
|
+
|
36
|
+
# Return the random generator for shuffling the dataset.
|
37
|
+
# @return [Random]
|
38
|
+
attr_reader :rng
|
39
|
+
|
40
|
+
# Create a new data splitter for random permutation cross validation with given group labels.
#
# @param n_splits [Integer] The number of folds.
# @param test_size [Float] The ratio of number of groups for test data.
# @param train_size [Float/Nil] The ratio of number of groups for train data.
# @param random_seed [Integer] The seed value used to initialize the random generator.
#   If nil is given, the value returned by Kernel#srand is used; note that calling
#   srand without an argument also reseeds the process-wide random generator.
def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
|
54
|
+
|
55
|
+
# Generate train and test data indices by randomly selecting group labels.
#
# Every call starts from a copy of the splitter's random generator, so repeated
# calls with the same groups return the same splits.
#
# @overload split(x, y, groups) -> Array
#   @param x [Numo::DFloat] (shape: [n_samples, n_features])
#     This argument exists to unify the interface between the splitters, it is not used in the method.
#   @param y [Numo::Int32] (shape: [n_samples])
#     This argument exists to unify the interface between the splitters, it is not used in the method.
#   @param groups [Numo::Int32] (shape: [n_samples])
#     The group labels to be used to generate data indices for random permutation cross validation.
#   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
# @raise [RangeError] If the number of test or train groups is out of range, or their total
#   exceeds the number of groups.
def split(_x, _y, groups)
  classes = groups.to_a.uniq.sort
  n_groups = classes.size
  n_test_groups = (@test_size * n_groups).ceil.to_i
  n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i

  unless n_test_groups.between?(1, n_groups)
    raise RangeError,
          'The number of groups in test split must be not less than 1 and not more than the number of groups.'
  end
  unless n_train_groups.between?(1, n_groups)
    raise RangeError,
          'The number of groups in train split must be not less than 1 and not more than the number of groups.'
  end
  if (n_test_groups + n_train_groups) > n_groups
    raise RangeError,
          'The total number of groups in test split and train split must be not more than the number of groups.'
  end

  # Work on a copy so that the state of @rng is left untouched.
  sub_rng = @rng.dup

  Array.new(@n_splits) do
    test_group_ids = classes.sample(n_test_groups, random: sub_rng)
    remaining = classes - test_group_ids
    # Without an explicit train size every non-test group is used for training.
    train_group_ids = @train_size.nil? ? remaining : remaining.sample(n_train_groups, random: sub_rng)
    test_ids = in1d(groups, test_group_ids).where.to_a
    train_ids = in1d(groups, train_group_ids).where.to_a
    [train_ids, test_ids]
  end
end
|
98
|
+
|
99
|
+
private
|
100
|
+
|
101
|
+
# Build a mask marking the elements of a that are equal to any value in b.
#
# @param a [Numo::NArray] The array whose elements are tested.
# @param b [Array] The values to look for.
# @return [Numo::Bit] A mask of length a.shape[0]; all zeros if b is empty.
def in1d(a, b)
  b.reduce(Numo::Bit.zeros(a.shape[0])) { |mask, v| mask | a.eq(v) }
end
|
106
|
+
end
|
107
|
+
end
|
108
|
+
end
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
# This module consists of the classes for model validation techniques.
|
7
|
+
module ModelSelection
|
8
|
+
# KFold is a class that generates the set of data indices for K-fold cross-validation.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/model_selection/k_fold'
|
12
|
+
#
|
13
|
+
# kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
|
14
|
+
# kf.split(samples, labels).each do |train_ids, test_ids|
|
15
|
+
# train_samples = samples[train_ids, true]
|
16
|
+
# test_samples = samples[test_ids, true]
|
17
|
+
# ...
|
18
|
+
# end
|
19
|
+
#
|
20
|
+
class KFold
|
21
|
+
include ::Rumale::Base::Splitter
|
22
|
+
|
23
|
+
# Return the number of folds.
|
24
|
+
# @return [Integer]
|
25
|
+
attr_reader :n_splits
|
26
|
+
|
27
|
+
# Return the flag indicating whether to shuffle the dataset.
|
28
|
+
# @return [Boolean]
|
29
|
+
attr_reader :shuffle
|
30
|
+
|
31
|
+
# Return the random generator for shuffling the dataset.
|
32
|
+
# @return [Random]
|
33
|
+
attr_reader :rng
|
34
|
+
|
35
|
+
# Create a new data splitter for K-fold cross validation.
#
# @param n_splits [Integer] The number of folds.
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
# @param random_seed [Integer] The seed value used to initialize the random generator.
#   If nil is given, the value returned by Kernel#srand is used; note that calling
#   srand without an argument also reseeds the process-wide random generator.
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
|
47
|
+
|
48
|
+
# Generate data indices for K-fold cross validation.
|
49
|
+
#
|
50
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
51
|
+
# The dataset to be used to generate data indices for K-fold cross validation.
|
52
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
53
|
+
def split(x, _y = nil)
  # Only the size of the first axis of x is read; the second argument is ignored.
  n_samples = x.shape[0]
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  # Shuffle with a copy of the generator so that @rng itself is never advanced.
  shuffler = @rng.dup
  ordered_ids = (0...n_samples).to_a
  ordered_ids = ordered_ids.shuffle(random: shuffler) if @shuffle
  # The first (n_samples % n_splits) folds receive one extra sample.
  base_size, n_larger_folds = n_samples.divmod(@n_splits)
  cursor = 0
  folds = Array.new(@n_splits) do |fold_id|
    fold_size = fold_id < n_larger_folds ? base_size + 1 : base_size
    fold = ordered_ids[cursor, fold_size]
    cursor += fold_size
    fold
  end
  # Each fold is the test set once; the remaining folds, in order, form the training set.
  folds.each_index.map do |test_fold|
    train_ids = (folds.take(test_fold) + folds.drop(test_fold + 1)).flatten
    [train_ids, folds[test_fold]]
  end
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# ShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
|
8
|
+
#
|
9
|
+
# @example
|
10
|
+
# require 'rumale/model_selection/shuffle_split'
|
11
|
+
#
|
12
|
+
# ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
|
13
|
+
# ss.split(samples, labels).each do |train_ids, test_ids|
|
14
|
+
# train_samples = samples[train_ids, true]
|
15
|
+
# test_samples = samples[test_ids, true]
|
16
|
+
# ...
|
17
|
+
# end
|
18
|
+
#
|
19
|
+
class ShuffleSplit
|
20
|
+
include ::Rumale::Base::Splitter
|
21
|
+
|
22
|
+
# Return the number of folds.
|
23
|
+
# @return [Integer]
|
24
|
+
attr_reader :n_splits
|
25
|
+
|
26
|
+
# Return the random generator for shuffling the dataset.
|
27
|
+
# @return [Random]
|
28
|
+
attr_reader :rng
|
29
|
+
|
30
|
+
# Create a new data splitter for random permutation cross validation.
|
31
|
+
#
|
32
|
+
# @param n_splits [Integer] The number of folds.
|
33
|
+
# @param test_size [Float] The ratio of number of samples for test data.
|
34
|
+
# @param train_size [Float] The ratio of number of samples for train data.
|
35
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
36
|
+
def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  # Ratios are stored as given; split converts them into sample counts.
  @train_size = train_size
  @test_size = test_size
  # When no seed is supplied, use the value returned by Kernel#srand.
  @random_seed = random_seed || srand
  @rng = Random.new(@random_seed)
end
|
44
|
+
|
45
|
+
# Generate data indices for random permutation cross validation.
|
46
|
+
#
|
47
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
48
|
+
# The dataset to be used to generate data indices for random permutation cross validation.
|
49
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
50
|
+
def split(x, _y = nil)
  # Only the size of the first axis of x is read; the second argument is ignored.
  n_samples = x.shape[0]
  # The test count is rounded up; an explicit train count is rounded down.
  n_test_samples = (@test_size * n_samples).ceil.to_i
  n_train_samples = if @train_size.nil?
                      n_samples - n_test_samples
                    else
                      (@train_size * n_samples).floor.to_i
                    end
  within_dataset = ->(count) { count.between?(1, n_samples) }
  unless within_dataset.call(@n_splits)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples.'
  end
  unless within_dataset.call(n_test_samples)
    raise RangeError,
          'The number of samples in test split must be not less than 1 and not more than the number of samples.'
  end
  unless within_dataset.call(n_train_samples)
    raise RangeError,
          'The number of samples in train split must be not less than 1 and not more than the number of samples.'
  end
  unless n_test_samples + n_train_samples <= n_samples
    raise RangeError,
          'The total number of samples in test split and train split must be not more than the number of samples.'
  end
  # Draw from a copy of the generator so that @rng itself is never advanced.
  sampler = @rng.dup
  all_ids = (0...n_samples).to_a
  Array.new(@n_splits) do
    # Test ids are drawn first; training ids come from what is left.
    test_ids = all_ids.sample(n_test_samples, random: sampler)
    remaining_ids = all_ids - test_ids
    train_ids = @train_size.nil? ? remaining_ids : remaining_ids.sample(n_train_samples, random: sampler)
    [train_ids, test_ids]
  end
end
|
84
|
+
end
|
85
|
+
end
|
86
|
+
end
|
@@ -0,0 +1,95 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# StratifiedKFold is a class that generates the set of data indices for K-fold cross-validation.
|
8
|
+
# The proportion of the number of samples in each class will be almost equal for each fold.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/model_selection/stratified_k_fold'
|
12
|
+
#
|
13
|
+
# kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
|
14
|
+
# kf.split(samples, labels).each do |train_ids, test_ids|
|
15
|
+
# train_samples = samples[train_ids, true]
|
16
|
+
# test_samples = samples[test_ids, true]
|
17
|
+
# ...
|
18
|
+
# end
|
19
|
+
#
|
20
|
+
class StratifiedKFold
|
21
|
+
include ::Rumale::Base::Splitter
|
22
|
+
|
23
|
+
# Return the number of folds.
|
24
|
+
# @return [Integer]
|
25
|
+
attr_reader :n_splits
|
26
|
+
|
27
|
+
# Return the flag indicating whether to shuffle the dataset.
|
28
|
+
# @return [Boolean]
|
29
|
+
attr_reader :shuffle
|
30
|
+
|
31
|
+
# Return the random generator for shuffling the dataset.
|
32
|
+
# @return [Random]
|
33
|
+
attr_reader :rng
|
34
|
+
|
35
|
+
# Create a new data splitter for stratified K-fold cross validation.
|
36
|
+
#
|
37
|
+
# @param n_splits [Integer] The number of folds.
|
38
|
+
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
|
39
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
40
|
+
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  # Fold configuration, stored as given.
  @shuffle = shuffle
  @n_splits = n_splits
  # When no seed is supplied, use the value returned by Kernel#srand.
  @random_seed = random_seed || srand
  # Generator that split copies before shuffling the per-class sample ids.
  @rng = Random.new(@random_seed)
end
|
47
|
+
|
48
|
+
# Generate data indices for stratified K-fold cross validation.
|
49
|
+
#
|
50
|
+
# @overload split(x, y) -> Array
|
51
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
52
|
+
# The dataset to be used to generate data indices for stratified K-fold cross validation.
|
53
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
54
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
55
|
+
# The labels to be used to generate data indices for stratified K-fold cross validation.
|
56
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
57
|
+
def split(_x, y)
  # Every class must be able to contribute at least one sample to each fold.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
  end
  # Shuffle with a copy of the generator so that @rng itself is never advanced.
  shuffler = @rng.dup
  class_labels = y.to_a.uniq
  # One list of folds per class, in order of first appearance of the label.
  folds_per_class = class_labels.map { |label| fold_sets(y, label, shuffler) }
  # Merge the per-class folds into one [train_ids, test_ids] pair per fold.
  Array.new(@n_splits) do |fold_id|
    train_test_sets(folds_per_class, fold_id)
  end
end
|
69
|
+
|
70
|
+
private
|
71
|
+
|
72
|
+
def valid_n_splits?(y)
  # True when n_splits is at least 2 and no larger than the size of every class in y.
  y.to_a.uniq.all? do |label|
    n_class_samples = y.eq(label).where.size
    @n_splits.between?(2, n_class_samples)
  end
end
|
75
|
+
|
76
|
+
def fold_sets(y, label, sub_rng)
  # Ids of the samples whose label equals the given one, optionally shuffled.
  class_ids = y.eq(label).where.to_a
  class_ids.shuffle!(random: sub_rng) if @shuffle
  # The first (size % n_splits) folds receive one extra sample.
  base_size, n_larger_folds = class_ids.size.divmod(@n_splits)
  Array.new(@n_splits) do |fold_id|
    class_ids.shift(fold_id < n_larger_folds ? base_size + 1 : base_size)
  end
end
|
86
|
+
|
87
|
+
def train_test_sets(fold_sets_each_class, fold_id)
  # For every class, the fold at fold_id goes to the test part and all others to the train part.
  parts_per_class = fold_sets_each_class.map do |folds|
    indexed_folds = folds.each_with_index
    train_part = indexed_folds.reject { |_, id| id == fold_id }.flat_map(&:first)
    test_part = indexed_folds.select { |_, id| id == fold_id }.flat_map(&:first)
    [train_part, test_part]
  end
  # Join the per-class parts into a single [train_ids, test_ids] pair.
  parts_per_class.transpose.map(&:flatten)
end
|
93
|
+
end
|
94
|
+
end
|
95
|
+
end
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# StratifiedShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
|
8
|
+
# The proportion of the number of samples in each class will be almost equal for each fold.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/model_selection/stratified_shuffle_split'
|
12
|
+
#
|
13
|
+
# ss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
|
14
|
+
# ss.split(samples, labels).each do |train_ids, test_ids|
|
15
|
+
# train_samples = samples[train_ids, true]
|
16
|
+
# test_samples = samples[test_ids, true]
|
17
|
+
# ...
|
18
|
+
# end
|
19
|
+
#
|
20
|
+
class StratifiedShuffleSplit
|
21
|
+
include ::Rumale::Base::Splitter
|
22
|
+
|
23
|
+
# Return the number of folds.
|
24
|
+
# @return [Integer]
|
25
|
+
attr_reader :n_splits
|
26
|
+
|
27
|
+
# Return the random generator for shuffling the dataset.
|
28
|
+
# @return [Random]
|
29
|
+
attr_reader :rng
|
30
|
+
|
31
|
+
# Create a new data splitter for stratified random permutation cross validation.
|
32
|
+
#
|
33
|
+
# @param n_splits [Integer] The number of folds.
|
34
|
+
# @param test_size [Float] The ratio of number of samples for test data.
|
35
|
+
# @param train_size [Float] The ratio of number of samples for train data.
|
36
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
37
|
+
def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  # Ratios are stored as given; split applies them to each class separately.
  @train_size = train_size
  @test_size = test_size
  # When no seed is supplied, use the value returned by Kernel#srand.
  @random_seed = random_seed || srand
  @rng = Random.new(@random_seed)
end
|
45
|
+
|
46
|
+
# Generate data indices for stratified random permutation cross validation.
|
47
|
+
#
|
48
|
+
# @overload split(x, y) -> Array
|
49
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
50
|
+
# The dataset to be used to generate data indices for stratified random permutation cross validation.
|
51
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
52
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
53
|
+
# The labels to be used to generate data indices for stratified random permutation cross validation.
|
54
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
55
|
+
def split(_x, y)
  # Builds @n_splits pairs of [train_ids, test_ids], sampling every class of y separately so
  # that each pair keeps roughly the class proportions of y. The first argument is ignored.
  # Raises ArgumentError when n_splits does not fit every class, and RangeError when the
  # requested test/train sizes cannot be satisfied within every class.
  # Draw from a copy of the generator so that @rng itself is never advanced.
  sub_rng = @rng.dup
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
  end
  sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
  # rubocop:disable Layout/LineLength
  unless enough_data_size_each_class?(y, @test_size, 'test')
    raise RangeError,
          'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
  end
  if @train_size.nil?
    # Without an explicit train size, every sample not drawn for testing is used for training,
    # so validate that integer remainder directly. Flooring (1.0 - test_size) * n instead is
    # subject to rounding: (1.0 - 0.9) * 10 evaluates to 0.9999999999999998 and floors to 0,
    # which wrongly rejects a split that leaves one training sample per class.
    all_classes_keep_train_samples = sample_ids_each_class.all? do |sample_ids|
      n_samples = sample_ids.size
      (n_samples - (@test_size * n_samples).ceil.to_i).between?(1, n_samples)
    end
    unless all_classes_keep_train_samples
      raise RangeError,
            'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
    end
  else
    unless enough_data_size_each_class?(y, @train_size, 'train')
      raise RangeError,
            'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
    end
    unless enough_data_size_each_class?(y, @train_size + @test_size, 'train')
      raise RangeError,
            'The total number of samples in test split and train split must be not more than the number of samples in each class.'
    end
  end
  # rubocop:enable Layout/LineLength
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do
    train_ids = []
    test_ids = []
    sample_ids_each_class.each do |sample_ids|
      n_samples = sample_ids.size
      n_test_samples = (@test_size * n_samples).ceil.to_i
      # Test ids are drawn first; training ids come from what is left of the class.
      class_test_ids = sample_ids.sample(n_test_samples, random: sub_rng)
      remaining_ids = sample_ids - class_test_ids
      test_ids.concat(class_test_ids)
      if @train_size.nil?
        train_ids.concat(remaining_ids)
      else
        n_train_samples = (@train_size * n_samples).floor.to_i
        train_ids.concat(remaining_ids.sample(n_train_samples, random: sub_rng))
      end
    end
    [train_ids, test_ids]
  end
end
|
97
|
+
|
98
|
+
private
|
99
|
+
|
100
|
+
def valid_n_splits?(y)
  # True when n_splits is at least 1 and no larger than the size of every class in y.
  y.to_a.uniq.all? do |label|
    n_class_samples = y.eq(label).where.size
    @n_splits.between?(1, n_class_samples)
  end
end
|
103
|
+
|
104
|
+
def enough_data_size_each_class?(y, data_size, data_type)
  # Test counts are rounded up; any other data_type is rounded down.
  rounding = data_type == 'test' ? :ceil : :floor
  # True when the rounded count lies between 1 and the class size for every class in y.
  y.to_a.uniq.all? do |label|
    n_class_samples = y.eq(label).where.size
    (data_size * n_class_samples).public_send(rounding).to_i.between?(1, n_class_samples)
  end
end
|
113
|
+
end
|
114
|
+
end
|
115
|
+
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
|
8
|
+
# It is assumed that the given dataset is already ordered by time information.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# require 'rumale/model_selection/time_series_split'
|
12
|
+
#
|
13
|
+
# cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
|
14
|
+
# x = Numo::DFloat.new(6, 2).rand
|
15
|
+
# cv.split(x, nil).each do |train_ids, test_ids|
|
16
|
+
# puts '---'
|
17
|
+
# pp train_ids
|
18
|
+
# pp test_ids
|
19
|
+
# end
|
20
|
+
#
|
21
|
+
# # ---
|
22
|
+
# # [0]
|
23
|
+
# # [1]
|
24
|
+
# # ---
|
25
|
+
# # [0, 1]
|
26
|
+
# # [2]
|
27
|
+
# # ---
|
28
|
+
# # [0, 1, 2]
|
29
|
+
# # [3]
|
30
|
+
# # ---
|
31
|
+
# # [0, 1, 2, 3]
|
32
|
+
# # [4]
|
33
|
+
# # ---
|
34
|
+
# # [0, 1, 2, 3, 4]
|
35
|
+
# # [5]
|
36
|
+
#
|
37
|
+
class TimeSeriesSplit
|
38
|
+
include ::Rumale::Base::Splitter
|
39
|
+
|
40
|
+
# Return the number of splits.
|
41
|
+
# @return [Integer]
|
42
|
+
attr_reader :n_splits
|
43
|
+
|
44
|
+
# Return the maximum number of training samples in a split.
|
45
|
+
# @return [Integer/Nil]
|
46
|
+
attr_reader :max_train_size
|
47
|
+
|
48
|
+
# Create a new data splitter for time series cross-validation.
|
49
|
+
#
|
50
|
+
# @param n_splits [Integer] The number of splits.
|
51
|
+
# @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
|
52
|
+
def initialize(n_splits: 5, max_train_size: nil)
  # Optional limit on the training window; nil means no limit was requested.
  @max_train_size = max_train_size
  # Number of [train_ids, test_ids] pairs that split produces.
  @n_splits = n_splits
end
|
56
|
+
|
57
|
+
# Generate data indices for time series cross-validation.
|
58
|
+
#
|
59
|
+
# @overload split(x, y) -> Array
|
60
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
61
|
+
# The dataset to be used to generate data indices for time series cross-validation.
|
62
|
+
# It is expected that the data will be ordered by time information.
|
63
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
64
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
65
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
66
|
+
def split(x, _y = nil)
  # Builds @n_splits pairs of [train_ids, test_ids] over a time-ordered dataset. Every test
  # window directly follows its training window, and the last test window ends at the final
  # sample. Only the size of the first axis of x is read; the second argument is ignored and
  # is optional, like in the other splitters.
  # Raises ArgumentError when n_splits + 1 is less than 2 or more than the number of samples.
  n_samples = x.shape[0]
  unless (@n_splits + 1).between?(2, n_samples)
    raise ArgumentError,
          'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
  end

  test_size, n_leftover = n_samples.divmod(@n_splits + 1)
  # Samples that do not divide evenly are absorbed once, by the first training window.
  # Multiplying them into every split start pushes later test windows past the dataset.
  first_test_start = test_size + n_leftover

  Array.new(@n_splits) do |n|
    test_start = first_test_start + (n * test_size)
    # The training window holds test_start samples, so that is what max_train_size limits.
    train_start = if @max_train_size.nil? || @max_train_size >= test_start
                    0
                  else
                    test_start - @max_train_size
                  end
    train_ids = Array(train_start...test_start)
    test_ids = Array(test_start...(test_start + test_size))
    [train_ids, test_ids]
  end
end
|
87
|
+
end
|
88
|
+
end
|
89
|
+
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
|
5
|
+
require_relative 'model_selection/version'
|
6
|
+
|
7
|
+
require_relative 'model_selection/cross_validation'
|
8
|
+
require_relative 'model_selection/function'
|
9
|
+
require_relative 'model_selection/grid_search_cv'
|
10
|
+
require_relative 'model_selection/group_k_fold'
|
11
|
+
require_relative 'model_selection/group_shuffle_split'
|
12
|
+
require_relative 'model_selection/k_fold'
|
13
|
+
require_relative 'model_selection/shuffle_split'
|
14
|
+
require_relative 'model_selection/stratified_k_fold'
|
15
|
+
require_relative 'model_selection/stratified_shuffle_split'
|
16
|
+
require_relative 'model_selection/time_series_split'
|
metadata
ADDED
@@ -0,0 +1,121 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: rumale-model_selection
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.24.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- yoshoku
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2022-12-31 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: numo-narray
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: 0.9.1
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: 0.9.1
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rumale-core
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: 0.24.0
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: 0.24.0
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rumale-evaluation_measure
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: 0.24.0
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: 0.24.0
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: rumale-preprocessing
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - "~>"
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: 0.24.0
|
62
|
+
type: :runtime
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - "~>"
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: 0.24.0
|
69
|
+
description: |
|
70
|
+
Rumale::ModelSelection provides model validation techniques,
|
71
|
+
such as k-fold cross-validation, time series cross-validation, and grid search,
|
72
|
+
with Rumale interface.
|
73
|
+
email:
|
74
|
+
- yoshoku@outlook.com
|
75
|
+
executables: []
|
76
|
+
extensions: []
|
77
|
+
extra_rdoc_files: []
|
78
|
+
files:
|
79
|
+
- LICENSE.txt
|
80
|
+
- README.md
|
81
|
+
- lib/rumale/model_selection.rb
|
82
|
+
- lib/rumale/model_selection/cross_validation.rb
|
83
|
+
- lib/rumale/model_selection/function.rb
|
84
|
+
- lib/rumale/model_selection/grid_search_cv.rb
|
85
|
+
- lib/rumale/model_selection/group_k_fold.rb
|
86
|
+
- lib/rumale/model_selection/group_shuffle_split.rb
|
87
|
+
- lib/rumale/model_selection/k_fold.rb
|
88
|
+
- lib/rumale/model_selection/shuffle_split.rb
|
89
|
+
- lib/rumale/model_selection/stratified_k_fold.rb
|
90
|
+
- lib/rumale/model_selection/stratified_shuffle_split.rb
|
91
|
+
- lib/rumale/model_selection/time_series_split.rb
|
92
|
+
- lib/rumale/model_selection/version.rb
|
93
|
+
homepage: https://github.com/yoshoku/rumale
|
94
|
+
licenses:
|
95
|
+
- BSD-3-Clause
|
96
|
+
metadata:
|
97
|
+
homepage_uri: https://github.com/yoshoku/rumale
|
98
|
+
source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-model_selection
|
99
|
+
changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
|
100
|
+
documentation_uri: https://yoshoku.github.io/rumale/doc/
|
101
|
+
rubygems_mfa_required: 'true'
|
102
|
+
post_install_message:
|
103
|
+
rdoc_options: []
|
104
|
+
require_paths:
|
105
|
+
- lib
|
106
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
111
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
112
|
+
requirements:
|
113
|
+
- - ">="
|
114
|
+
- !ruby/object:Gem::Version
|
115
|
+
version: '0'
|
116
|
+
requirements: []
|
117
|
+
rubygems_version: 3.3.26
|
118
|
+
signing_key:
|
119
|
+
specification_version: 4
|
120
|
+
summary: Rumale::ModelSelection provides model validation techniques with Rumale interface.
|
121
|
+
test_files: []
|