rumale-model_selection 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 28fa6e0e0c5832366f7130aa8185283b9ae316220b4cebec64774acab7f5cfd6
4
+ data.tar.gz: c85aa131bbcc278f99f2aa91dc6d529c291a2274dbf46c96b468191fefc8015d
5
+ SHA512:
6
+ metadata.gz: ad3df526ec16120843536a9f199ba3f0c5856f40ac5a9de545b34b43bb60a798ad17b1c75588e7cfc59caa34bb886c7f5f180aacb9f24e338d70ebd389bd58aa
7
+ data.tar.gz: 8214f4cec33f056f49d2a479d23f16bc264b92c1e457af695c20d6b5b5233f78c2ab735d6ed584c00ce5333ec9779a8966d6bfd5c96128b7984cda6269b88732
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022 Atsushi Tatsuma
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holder nor the names of its
15
+ contributors may be used to endorse or promote products derived from
16
+ this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # Rumale::ModelSelection
2
+
3
+ [![Gem Version](https://badge.fury.io/rb/rumale-model_selection.svg)](https://badge.fury.io/rb/rumale-model_selection)
4
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-model_selection/LICENSE.txt)
5
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
6
+
7
+ Rumale is a machine learning library in Ruby.
8
+ Rumale::ModelSelection provides model validation techniques,
9
+ such as k-fold cross-validation, time series cross-validation, and grid search,
10
+ with the Rumale interface.
11
+
12
+ ## Installation
13
+
14
+ Add this line to your application's Gemfile:
15
+
16
+ ```ruby
17
+ gem 'rumale-model_selection'
18
+ ```
19
+
20
+ And then execute:
21
+
22
+ $ bundle install
23
+
24
+ Or install it yourself as:
25
+
26
+ $ gem install rumale-model_selection
27
+
28
+ ## Documentation
29
+
30
+ - [Rumale API Documentation - ModelSelection](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
31
+
32
+ ## License
33
+
34
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
@@ -0,0 +1,107 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/evaluation_measure/log_loss'
4
+
5
module Rumale
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    # CrossValidation is a class that evaluates a given classifier with cross-validation method.
    #
    # @example
    #   require 'rumale/linear_model'
    #   require 'rumale/model_selection/stratified_k_fold'
    #   require 'rumale/model_selection/cross_validation'
    #
    #   svc = Rumale::LinearModel::SVC.new
    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
    #   report = cv.perform(samples, labels)
    #   mean_test_score = report[:test_score].inject(:+) / kf.n_splits
    #
    class CrossValidation
      # Return the classifier of which performance is evaluated.
      # @return [Classifier]
      attr_reader :estimator

      # Return the splitter that divides dataset.
      # @return [Splitter]
      attr_reader :splitter

      # Return the evaluator that calculates score.
      # @return [Evaluator]
      attr_reader :evaluator

      # Return the flag indicating whether to calculate the score of training dataset.
      # @return [Boolean]
      attr_reader :return_train_score

      # Create a new evaluator with cross-validation method.
      #
      # @param estimator [Classifier] The classifier of which performance is evaluated.
      # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
      # @param evaluator [Evaluator] The evaluator that calculates score of estimator results.
      #   If nil is given, the score method of the estimator is used.
      # @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
      def initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false)
        @estimator = estimator
        @splitter = splitter
        @evaluator = evaluator
        @return_train_score = return_train_score
      end

      # Perform the evaluation of given classifier with cross-validation method.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to evaluate the estimator.
      # @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs])
      #   The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor.
      # @return [Hash] The report summarizing the results of cross-validation.
      #   * :fit_time (Array<Float>) The elapsed times in seconds of fitting the estimator for each split,
      #     measured with a monotonic clock.
      #   * :test_score (Array<Float>) The scores of testing dataset for each split.
      #   * :train_score (Array<Float>) The scores of training dataset for each split. This entry is nil if
      #     return_train_score is false.
      def perform(x, y)
        report = { test_score: [], train_score: nil, fit_time: [] }
        report[:train_score] = [] if @return_train_score
        @splitter.split(x, y).each do |train_ids, test_ids|
          # For kernel machines the columns are restricted to the training samples too
          # (presumably x is a precomputed kernel matrix in that case); otherwise all columns are kept.
          feature_ids = kernel_machine? ? train_ids : true
          train_x = x[train_ids, feature_ids]
          train_y = slice_rows(y, train_ids)
          test_x = x[test_ids, feature_ids]
          test_y = slice_rows(y, test_ids)
          report[:fit_time].push(measure_fit_time(train_x, train_y))
          report[:test_score].push(score_on(test_x, test_y))
          report[:train_score].push(score_on(train_x, train_y)) if @return_train_score
        end
        report
      end

      private

      def kernel_machine?
        @estimator.class.name.include?('Rumale::KernelMachine')
      end

      def log_loss?
        @evaluator.is_a?(::Rumale::EvaluationMeasure::LogLoss)
      end

      # Select the given rows of a 1-D label vector or a 2-D target matrix.
      def slice_rows(y, ids)
        y.shape[1].nil? ? y[ids] : y[ids, true]
      end

      # Fit the estimator and return the elapsed time in seconds as a Float.
      # A monotonic clock is used so the measurement is sub-second and unaffected by wall-clock changes.
      def measure_fit_time(x, y)
        started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)
        @estimator.fit(x, y)
        Process.clock_gettime(Process::CLOCK_MONOTONIC) - started_at
      end

      # Score the fitted estimator on (x, y) with either its own score method or the configured evaluator.
      # LogLoss needs class probabilities, so predict_proba is used for it instead of predict.
      def score_on(x, y)
        if @evaluator.nil?
          @estimator.score(x, y)
        elsif log_loss?
          @evaluator.score(y, @estimator.predict_proba(x))
        else
          @evaluator.score(y, @estimator.predict(x))
        end
      end
    end
  end
end
@@ -0,0 +1,47 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require 'rumale/model_selection/shuffle_split'
6
+ require 'rumale/model_selection/stratified_shuffle_split'
7
+
8
module Rumale
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    module_function

    # Randomly split a dataset into training and testing data.
    #
    # @example
    #   require 'rumale/model_selection/function'
    #
    #   x_train, x_test, y_train, y_test = Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)
    #
    # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to be used to generate data indices.
    # @param y [Numo::Int32/Numo::DFloat/Nil] (shape: [n_samples] / [n_samples, n_outputs]) The labels or target values
    #   to be split together with x. If stratify is true, the labels are also used to generate data indices
    #   for stratified random permutation. If nil is given, only x is split.
    # @param test_size [Float] The ratio of number of samples for test data.
    # @param train_size [Float] The ratio of number of samples for train data.
    #   If nil is given, it sets to 1 - test_size.
    # @param stratify [Boolean] The flag indicating whether to perform stratify split.
    # @param random_seed [Integer] The seed value using to initialize the random generator.
    # @return [Array<Numo::NArray>] The set of training and testing data: [x_train, x_test, y_train, y_test],
    #   or [x_train, x_test] when y is nil.
    def train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil)
      splitter_class = if stratify
                         ::Rumale::ModelSelection::StratifiedShuffleSplit
                       else
                         ::Rumale::ModelSelection::ShuffleSplit
                       end
      splitter = splitter_class.new(n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed)
      train_ids, test_ids = splitter.split(x, y).first
      x_train = x[train_ids, true].dup
      x_test = x[test_ids, true].dup
      return [x_train, x_test] if y.nil?

      # A single index on a multi-dimensional Numo array addresses the flattened data,
      # so 2-D targets need an explicit row selection (same pattern as CrossValidation).
      multi_output = !y.shape[1].nil?
      y_train = (multi_output ? y[train_ids, true] : y[train_ids]).dup
      y_test = (multi_output ? y[test_ids, true] : y[test_ids]).dup
      [x_train, x_test, y_train, y_test]
    end
  end
end
@@ -0,0 +1,213 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/model_selection/cross_validation'
5
+
6
module Rumale
  module ModelSelection
    # GridSearchCV is a class that performs hyperparameter optimization with grid search method.
    #
    # @example
    #   require 'rumale/ensemble'
    #   require 'rumale/model_selection/stratified_k_fold'
    #   require 'rumale/model_selection/grid_search_cv'
    #
    #   rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
    #   pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    # @example
    #   rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
    #   svc = Rumale::LinearModel::SVC.new(random_seed: 1)
    #   pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
    #   pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    class GridSearchCV < ::Rumale::Base::Estimator
      # Return the result of cross validation for each parameter.
      # @return [Hash]
      attr_reader :cv_results

      # Return the score of the estimator learned with the best parameter.
      # @return [Float]
      attr_reader :best_score

      # Return the best parameter set.
      # @return [Hash]
      attr_reader :best_params

      # Return the index of the best parameter.
      # @return [Integer]
      attr_reader :best_index

      # Return the estimator learned with the best parameter.
      # @return [Estimator]
      attr_reader :best_estimator

      # Create a new grid search method.
      #
      # @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with grid search method.
      # @param param_grid [Hash/Array<Hash>] The parameter sets are represented with an array of hashes that
      #   consist of parameter names as keys and arrays of parameter values as values.
      #   A single Hash is treated as a one-element array. An empty Hash yields a single candidate
      #   that keeps the estimator's current parameters.
      #   For a pipeline estimator, keys are written as "<step name>__<parameter name>".
      # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
      # @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
      #   If nil is given, the score method of estimator is used to evaluation.
      # @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
      #   evaluation score is larger.
      # @raise [TypeError] If param_grid is not a Hash or an Array of Hashes whose values are Arrays.
      def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
        super()
        # Deep copies (via Marshal) keep the search independent of later changes to the caller's objects.
        @params = {
          param_grid: valid_param_grid(param_grid),
          estimator: Marshal.load(Marshal.dump(estimator)),
          splitter: Marshal.load(Marshal.dump(splitter)),
          evaluator: Marshal.load(Marshal.dump(evaluator)),
          greater_is_better: greater_is_better
        }
      end

      # Fit the model with given training data and all sets of parameters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
      # @return [GridSearchCV] The learned estimator with grid search.
      def fit(x, y)
        init_attrs

        param_combinations.each do |prm_set|
          prm_set.each do |prms|
            report = perform_cross_validation(x, y, prms)
            store_cv_result(prms, report)
          end
        end

        find_best_params

        # Refit on the whole dataset with the winning parameters.
        @best_estimator = configurated_estimator(@best_params)
        @best_estimator.fit(x, y)
        self
      end

      # Call the decision_function method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
      def decision_function(x)
        @best_estimator.decision_function(x)
      end

      # Call the predict method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
      # @return [Numo::NArray] Predicted results.
      def predict(x)
        @best_estimator.predict(x)
      end

      # Call the predict_log_proba method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
      def predict_log_proba(x)
        @best_estimator.predict_log_proba(x)
      end

      # Call the predict_proba method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        @best_estimator.predict_proba(x)
      end

      # Call the score method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
      # @return [Float] The score of estimator.
      def score(x, y)
        @best_estimator.score(x, y)
      end

      private

      def valid_param_grid(grid)
        raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)

        grid = [grid] if grid.is_a?(Hash)
        grid.each do |h|
          raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
          raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
        end
        grid
      end

      # Expand each Hash in the grid into the list of all parameter combinations it describes.
      def param_combinations
        @param_combinations ||= @params[:param_grid].map do |prm|
          # One axis per parameter: an array of [name, value] pairs. Their Cartesian product enumerates the grid.
          axes = prm.sort.to_h.map { |k, v| [k].product(v) }
          # An empty Hash has no axes; treat it as one candidate that keeps the estimator's own parameters.
          next [{}] if axes.empty?

          axes.first.product(*axes.drop(1)).map(&:to_h)
        end
      end

      def perform_cross_validation(x, y, prms)
        est = configurated_estimator(prms)
        cv = ::Rumale::ModelSelection::CrossValidation.new(estimator: est, splitter: @params[:splitter],
                                                           evaluator: @params[:evaluator], return_train_score: true)
        cv.perform(x, y)
      end

      # Return a deep copy of the base estimator with the given parameters applied.
      def configurated_estimator(prms)
        estimator = Marshal.load(Marshal.dump(@params[:estimator]))
        if pipeline?
          prms.each do |k, v|
            # Pipeline parameters are named "<step>__<param>"; split only at the first separator
            # so the parameter part is not truncated if it contains "__" itself.
            est_name, prm_name = k.to_s.split('__', 2)
            estimator.steps[est_name.to_sym].params[prm_name.to_sym] = v
          end
        else
          prms.each { |k, v| estimator.params[k] = v }
        end
        estimator
      end

      def init_attrs
        @cv_results = %i[mean_test_score std_test_score
                         mean_train_score std_train_score
                         mean_fit_time std_fit_time params].to_h { |v| [v, []] }
        @best_score = nil
        @best_params = nil
        @best_index = nil
        @best_estimator = nil
      end

      def store_cv_result(prms, report)
        test_scores = Numo::DFloat[*report[:test_score]]
        train_scores = Numo::DFloat[*report[:train_score]]
        fit_times = Numo::DFloat[*report[:fit_time]]
        @cv_results[:mean_test_score].push(test_scores.mean)
        @cv_results[:std_test_score].push(test_scores.stddev)
        @cv_results[:mean_train_score].push(train_scores.mean)
        @cv_results[:std_train_score].push(train_scores.stddev)
        @cv_results[:mean_fit_time].push(fit_times.mean)
        @cv_results[:std_fit_time].push(fit_times.stddev)
        @cv_results[:params].push(prms)
      end

      def find_best_params
        @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
        @best_index = @cv_results[:mean_test_score].index(@best_score)
        @best_params = @cv_results[:params][@best_index]
      end

      def pipeline?
        @params[:estimator].class.name.include?('Rumale::Pipeline')
      end
    end
  end
end
@@ -0,0 +1,93 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+ require 'rumale/preprocessing/label_encoder'
5
+
6
module Rumale
  module ModelSelection
    # GroupKFold generates the set of data indices for K-fold cross-validation with group labels.
    # Data points that share a group label always end up in the same fold.
    # The number of distinct groups must be greater than or equal to the number of splits.
    #
    # @example
    #   require 'rumale/model_selection/group_k_fold'
    #
    #   cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 3, 4]
    #   # [5, 6, 7]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #
    class GroupKFold
      include ::Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Create a new data splitter for grouped K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      def initialize(n_splits: 5)
        @n_splits = n_splits
      end

      # Generate data indices for grouped K-fold cross validation.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for grouped K-fold cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for grouped K-fold cross validation.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   @raise [ArgumentError] If there are fewer distinct groups than folds.
      def split(x, _y, groups)
        # Map arbitrary group labels onto consecutive integers 0...n_groups.
        label_encoder = ::Rumale::Preprocessing::LabelEncoder.new
        encoded_groups = label_encoder.fit_transform(groups)
        group_count = label_encoder.classes.size

        if group_count < @n_splits
          raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.'
        end

        fold_of_group = assign_groups_to_folds(encoded_groups.bincount, group_count)
        fold_of_sample = fold_of_group[encoded_groups]
        all_ids = (0...x.shape[0]).to_a

        Array.new(@n_splits) do |fold|
          held_out = fold_of_sample.eq(fold).where.to_a
          [all_ids - held_out, held_out]
        end
      end

      private

      # Greedily place groups, in descending order of size, into the fold that currently holds the fewest samples.
      # Returns an array that maps each encoded group id to its fold index.
      def assign_groups_to_folds(group_sizes, group_count)
        order = group_sizes.sort_index.reverse
        fold_loads = Numo::Int32.zeros(@n_splits)
        fold_of_group = Numo::Int32.zeros(group_count)
        group_sizes[order].each_with_index do |size, rank|
          lightest = fold_loads.min_index
          fold_loads[lightest] += size
          fold_of_group[order[rank]] = lightest
        end
        fold_of_group
      end
    end
  end
end
@@ -0,0 +1,108 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # GroupShuffleSplit generates train/test data indices for random permutation cross-validation
    # by randomly choosing which group labels go into each split.
    #
    # @example
    #   require 'rumale/model_selection/group_shuffle_split'
    #
    #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #
    class GroupShuffleSplit
      include Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation with given group labels.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of number of groups for test data.
      # @param train_size [Float/Nil] The ratio of number of groups for train data.
      #   If nil is given, every group not drawn for the test split is used for training.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   If nil is given, the value returned by Kernel#srand is used.
      def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate train and test data indices by randomly selecting group labels.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for random permutation cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for random permutation cross validation.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   @raise [RangeError] If the derived numbers of test/train groups are out of range.
      def split(_x, _y, groups)
        labels = groups.to_a.uniq.sort
        total = labels.size
        test_count = (@test_size * total).ceil.to_i
        train_count = if @train_size.nil?
                        total - test_count
                      else
                        (@train_size * total).floor.to_i
                      end
        check_group_counts(test_count, train_count, total)

        # Work on a copy so that repeated calls to split yield the same folds.
        local_rng = @rng.dup

        Array.new(@n_splits) do
          held_out = labels.sample(test_count, random: local_rng)
          remaining = labels - held_out
          kept = @train_size.nil? ? remaining : remaining.sample(train_count, random: local_rng)
          [in1d(groups, kept).where.to_a, in1d(groups, held_out).where.to_a]
        end
      end

      private

      # Raise RangeError unless the test and train group counts are each in 1..total and together fit in total.
      def check_group_counts(test_count, train_count, total)
        unless test_count.between?(1, total)
          raise RangeError,
                'The number of groups in test split must be not less than 1 and not more than the number of groups.'
        end
        unless train_count.between?(1, total)
          raise RangeError,
                'The number of groups in train split must be not less than 1 and not more than the number of groups.'
        end
        return if (test_count + train_count) <= total

        raise RangeError,
              'The total number of groups in test split and train split must be not more than the number of groups.'
      end

      # Return a Numo::Bit mask that is set where an element of a equals any value in b.
      def in1d(a, b)
        b.inject(Numo::Bit.zeros(a.shape[0])) { |mask, v| mask | a.eq(v) }
      end
    end
  end
end
@@ -0,0 +1,78 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    # KFold generates the set of data indices for K-fold cross-validation.
    #
    # @example
    #   require 'rumale/model_selection/k_fold'
    #
    #   kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class KFold
      include ::Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   If nil is given, the value returned by Kernel#srand is used.
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        @n_splits = n_splits
        @shuffle = shuffle
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for K-fold cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for K-fold cross validation.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      # @raise [ArgumentError] If n_splits is not in 2..n_samples.
      def split(x, _y = nil)
        n_samples = x.shape.first
        unless @n_splits.between?(2, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples.'
        end

        # Shuffle a copy of the generator so that repeated calls yield the same folds.
        local_rng = @rng.dup
        ids = Array(0...n_samples)
        ids.shuffle!(random: local_rng) if @shuffle

        # The first (n_samples % n_splits) folds receive one extra sample.
        base_size = n_samples / @n_splits
        n_larger = n_samples % @n_splits
        folds = Array.new(@n_splits) { |k| ids.shift(k < n_larger ? base_size + 1 : base_size) }

        folds.each_with_index.map do |test_ids, k|
          train_ids = folds.reject.with_index { |_, j| j == k }.flatten
          [train_ids, test_ids]
        end
      end
    end
  end
end
@@ -0,0 +1,86 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # ShuffleSplit generates the set of data indices for random permutation cross-validation.
    #
    # @example
    #   require 'rumale/model_selection/shuffle_split'
    #
    #   ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
    #   ss.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class ShuffleSplit
      include ::Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of number of samples for test data.
      # @param train_size [Float] The ratio of number of samples for train data.
      #   If nil is given, every sample not drawn for the test split is used for training.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   If nil is given, the value returned by Kernel#srand is used.
      def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for random permutation cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for random permutation cross validation.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      # @raise [ArgumentError] If n_splits is not in 1..n_samples.
      # @raise [RangeError] If the derived numbers of test/train samples are out of range.
      def split(x, _y = nil)
        n_samples = x.shape.first
        n_test = (@test_size * n_samples).ceil.to_i
        n_train = if @train_size.nil?
                    n_samples - n_test
                  else
                    (@train_size * n_samples).floor.to_i
                  end
        check_sizes(n_samples, n_test, n_train)

        # Sample from a copy of the generator so that repeated calls yield the same splits.
        local_rng = @rng.dup
        all_ids = Array(0...n_samples)

        Array.new(@n_splits) do
          test_ids = all_ids.sample(n_test, random: local_rng)
          rest = all_ids - test_ids
          train_ids = @train_size.nil? ? rest : rest.sample(n_train, random: local_rng)
          [train_ids, test_ids]
        end
      end

      private

      # Validate n_splits and the derived split sizes, raising the same errors in the same order as documented.
      def check_sizes(n_samples, n_test, n_train)
        unless @n_splits.between?(1, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples.'
        end
        unless n_test.between?(1, n_samples)
          raise RangeError,
                'The number of samples in test split must be not less than 1 and not more than the number of samples.'
        end
        unless n_train.between?(1, n_samples)
          raise RangeError,
                'The number of samples in train split must be not less than 1 and not more than the number of samples.'
        end
        return if (n_test + n_train) <= n_samples

        raise RangeError,
              'The total number of samples in test split and train split must be not more than the number of samples.'
      end
    end
  end
end
@@ -0,0 +1,95 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # StratifiedKFold splits a labeled dataset into K folds for cross-validation, keeping the class
    # proportions of every fold close to those of the whole dataset.
    # The samples of each class are cut into n_splits chunks whose sizes differ by at most one.
    # Fold k then tests on the k-th chunk of every class and trains on all remaining chunks.
    #
    # @example
    #   require 'rumale/model_selection/stratified_k_fold'
    #
    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class StratifiedKFold
      include ::Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for stratified K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] Whether to shuffle the samples of each class before cutting them into folds.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, the value returned by Kernel#srand is used (srand also reseeds Ruby's global generator).
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        @n_splits = n_splits
        @shuffle = shuffle
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for stratified K-fold cross validation.
      #
      # @overload split(x, y) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     Ignored. It exists only so that all splitters share the same interface.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     The labels used to stratify the folds.
      #   @return [Array] n_splits pairs of [train_ids, test_ids], one pair per fold.
      # @raise [ArgumentError] If n_splits is less than 2 or greater than the number of samples of some class.
      def split(_x, y)
        unless valid_n_splits?(y)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
        end

        # Shuffle with a copy of the generator so that @rng itself is never advanced.
        shuffler = @rng.dup
        # Chunks of sample indices for each class; classes are taken in order of first appearance in y.
        chunks_per_class = y.to_a.uniq.map { |label| fold_sets(y, label, shuffler) }
        Array.new(@n_splits) { |fold_id| train_test_sets(chunks_per_class, fold_id) }
      end

      private

      # True when 2 <= n_splits <= (number of samples) holds for every class in y.
      def valid_n_splits?(y)
        y.to_a.uniq.all? { |label| @n_splits.between?(2, y.eq(label).where.size) }
      end

      # Cut the indices of the samples labeled +label+ into n_splits consecutive chunks.
      # The first (size % n_splits) chunks receive one extra sample.
      def fold_sets(y, label, sub_rng)
        ids = y.eq(label).where.to_a
        ids.shuffle!(random: sub_rng) if @shuffle
        chunk_size = ids.size / @n_splits
        n_larger = ids.size % @n_splits
        Array.new(@n_splits) { |n| ids.shift(n < n_larger ? chunk_size + 1 : chunk_size) }
      end

      # Build the [train_ids, test_ids] pair for one fold.
      # The fold_id-th chunk of every class is held out for testing; every other chunk goes to training.
      # Both lists are concatenated in class order.
      def train_test_sets(fold_sets_each_class, fold_id)
        pairs = fold_sets_each_class.map do |chunks|
          held_out = chunks[fold_id]
          rest = chunks.reject.with_index { |_, idx| idx == fold_id }.flatten
          [rest, held_out]
        end
        pairs.transpose.map(&:flatten)
      end
    end
  end
end
@@ -0,0 +1,115 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # StratifiedShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
    # The proportion of the number of samples in each class will be almost equal for each fold.
    # Test and train samples are drawn independently from every class, so each split reflects the class balance of y.
    #
    # @example
    #   require 'rumale/model_selection/stratified_shuffle_split'
    #
    #   ss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
    #   ss.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class StratifiedShuffleSplit
      include ::Rumale::Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of number of samples for test data (rounded up per class).
      # @param train_size [Float] The ratio of number of samples for train data (rounded down per class).
      #   When nil, every sample of a class that is not in the test split is used for training.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, the value returned by Kernel#srand is used.
      def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        @random_seed = random_seed
        @random_seed ||= srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for stratified random permutation cross validation.
      #
      # @overload split(x, y) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for stratified random permutation cross validation.
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     The labels to be used to generate data indices for stratified random permutation cross validation.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      # @raise [ArgumentError] If n_splits is not between 1 and the number of samples of every class.
      # @raise [RangeError] If the test split, the train split, or their sum would not fit in some class.
      def split(_x, y)
        # Sampling uses a copy of the generator, so @rng is never advanced.
        sub_rng = @rng.dup
        # Indices of each class, with classes in order of first appearance in y.
        sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
        check_split_sizes(sample_ids_each_class.map(&:size))
        Array.new(@n_splits) do
          train_ids = []
          test_ids = []
          sample_ids_each_class.each do |sample_ids|
            n_samples = sample_ids.size
            class_test_ids = sample_ids.sample(n_test_samples(n_samples), random: sub_rng)
            # Only this class's test ids need removing: the classes are disjoint.
            rest_ids = sample_ids - class_test_ids
            class_train_ids = @train_size.nil? ? rest_ids : rest_ids.sample(n_train_samples(n_samples), random: sub_rng)
            test_ids.concat(class_test_ids)
            train_ids.concat(class_train_ids)
          end
          [train_ids, test_ids]
        end
      end

      private

      # Number of test samples drawn from a class holding n_samples samples.
      def n_test_samples(n_samples)
        (@test_size * n_samples).ceil.to_i
      end

      # Number of training samples drawn from a class holding n_samples samples.
      def n_train_samples(n_samples)
        return n_samples - n_test_samples(n_samples) if @train_size.nil?

        (@train_size * n_samples).floor.to_i
      end

      # Validate the sizes that split will actually request from each class.
      # The counts come from the same helpers used for sampling, so float rounding in the ratios
      # cannot let an over-sized request slip through and be silently truncated.
      def check_split_sizes(class_sizes)
        unless class_sizes.all? { |n_samples| @n_splits.between?(1, n_samples) }
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
        end
        # rubocop:disable Layout/LineLength
        unless class_sizes.all? { |n_samples| n_test_samples(n_samples).between?(1, n_samples) }
          raise RangeError,
                'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
        end
        unless class_sizes.all? { |n_samples| n_train_samples(n_samples).between?(1, n_samples) }
          raise RangeError,
                'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
        end
        unless class_sizes.all? { |n_samples| n_test_samples(n_samples) + n_train_samples(n_samples) <= n_samples }
          raise RangeError,
                'The total number of samples in test split and train split must be not more than the number of samples in each class.'
        end
        # rubocop:enable Layout/LineLength
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
    # It is assumed that the given dataset is already ordered by time.
    # The samples are divided into n_splits + 1 consecutive chunks of equal size, and any remainder goes to the first chunk.
    # Split n tests on chunk n + 1 and trains on the samples that precede it, so training data always comes before test data.
    #
    # @example
    #   require 'rumale/model_selection/time_series_split'
    #
    #   cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
    #   x = Numo::DFloat.new(6, 2).rand
    #   cv.split(x, nil).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0]
    #   # [1]
    #   # ---
    #   # [0, 1]
    #   # [2]
    #   # ---
    #   # [0, 1, 2]
    #   # [3]
    #   # ---
    #   # [0, 1, 2, 3]
    #   # [4]
    #   # ---
    #   # [0, 1, 2, 3, 4]
    #   # [5]
    #
    class TimeSeriesSplit
      include ::Rumale::Base::Splitter

      # Return the number of splits.
      # @return [Integer]
      attr_reader :n_splits

      # Return the maximum number of training samples in a split.
      # @return [Integer/Nil]
      attr_reader :max_train_size

      # Create a new data splitter for time series cross-validation.
      #
      # @param n_splits [Integer] The number of splits.
      # @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
      #   When the preceding samples outnumber it, only the most recent max_train_size samples are used.
      def initialize(n_splits: 5, max_train_size: nil)
        @n_splits = n_splits
        @max_train_size = max_train_size
      end

      # Generate data indices for time series cross-validation.
      #
      # @overload split(x, y) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for time series cross-validation.
      #     It is expected that the data will be ordered by time information.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      # @raise [ArgumentError] If n_splits + 1 is less than 2 or greater than the number of samples.
      def split(x, _y)
        n_samples = x.shape[0]
        unless (@n_splits + 1).between?(2, n_samples)
          raise ArgumentError,
                'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
        end

        test_size = n_samples / (@n_splits + 1)
        # The first training chunk absorbs the remainder, so the last test window ends exactly at n_samples.
        offset = test_size + n_samples % (@n_splits + 1)

        Array.new(@n_splits) do |n|
          # Test windows are consecutive blocks of test_size samples that follow the first chunk.
          start = offset + n * test_size
          # Train on everything before the test window, keeping only the latest max_train_size samples if it is set.
          train_ids = if !@max_train_size.nil? && @max_train_size < start
                        Array((start - @max_train_size)...start)
                      else
                        Array(0...start)
                      end
          test_ids = Array(start...(start + test_size))
          [train_ids, test_ids]
        end
      end
    end
  end
end
@@ -0,0 +1,8 @@
1
+ # frozen_string_literal: true
2
+
3
module Rumale
  module ModelSelection
    # Version string of the rumale-model_selection gem.
    # @!visibility private
    VERSION = '0.24.0'
  end
end
@@ -0,0 +1,16 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require_relative 'model_selection/version'
6
+
7
+ require_relative 'model_selection/cross_validation'
8
+ require_relative 'model_selection/function'
9
+ require_relative 'model_selection/grid_search_cv'
10
+ require_relative 'model_selection/group_k_fold'
11
+ require_relative 'model_selection/group_shuffle_split'
12
+ require_relative 'model_selection/k_fold'
13
+ require_relative 'model_selection/shuffle_split'
14
+ require_relative 'model_selection/stratified_k_fold'
15
+ require_relative 'model_selection/stratified_shuffle_split'
16
+ require_relative 'model_selection/time_series_split'
metadata ADDED
@@ -0,0 +1,121 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: rumale-model_selection
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.24.0
5
+ platform: ruby
6
+ authors:
7
+ - yoshoku
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-12-31 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.9.1
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.9.1
27
+ - !ruby/object:Gem::Dependency
28
+ name: rumale-core
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: 0.24.0
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: 0.24.0
41
+ - !ruby/object:Gem::Dependency
42
+ name: rumale-evaluation_measure
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: 0.24.0
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: 0.24.0
55
+ - !ruby/object:Gem::Dependency
56
+ name: rumale-preprocessing
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - "~>"
60
+ - !ruby/object:Gem::Version
61
+ version: 0.24.0
62
+ type: :runtime
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - "~>"
67
+ - !ruby/object:Gem::Version
68
+ version: 0.24.0
69
+ description: |
70
+ Rumale::ModelSelection provides model validation techniques,
71
+ such as k-fold cross-validation, time series cross-validation, and grid search,
72
+ with Rumale interface.
73
+ email:
74
+ - yoshoku@outlook.com
75
+ executables: []
76
+ extensions: []
77
+ extra_rdoc_files: []
78
+ files:
79
+ - LICENSE.txt
80
+ - README.md
81
+ - lib/rumale/model_selection.rb
82
+ - lib/rumale/model_selection/cross_validation.rb
83
+ - lib/rumale/model_selection/function.rb
84
+ - lib/rumale/model_selection/grid_search_cv.rb
85
+ - lib/rumale/model_selection/group_k_fold.rb
86
+ - lib/rumale/model_selection/group_shuffle_split.rb
87
+ - lib/rumale/model_selection/k_fold.rb
88
+ - lib/rumale/model_selection/shuffle_split.rb
89
+ - lib/rumale/model_selection/stratified_k_fold.rb
90
+ - lib/rumale/model_selection/stratified_shuffle_split.rb
91
+ - lib/rumale/model_selection/time_series_split.rb
92
+ - lib/rumale/model_selection/version.rb
93
+ homepage: https://github.com/yoshoku/rumale
94
+ licenses:
95
+ - BSD-3-Clause
96
+ metadata:
97
+ homepage_uri: https://github.com/yoshoku/rumale
98
+ source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-model_selection
99
+ changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
100
+ documentation_uri: https://yoshoku.github.io/rumale/doc/
101
+ rubygems_mfa_required: 'true'
102
+ post_install_message:
103
+ rdoc_options: []
104
+ require_paths:
105
+ - lib
106
+ required_ruby_version: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - ">="
109
+ - !ruby/object:Gem::Version
110
+ version: '0'
111
+ required_rubygems_version: !ruby/object:Gem::Requirement
112
+ requirements:
113
+ - - ">="
114
+ - !ruby/object:Gem::Version
115
+ version: '0'
116
+ requirements: []
117
+ rubygems_version: 3.3.26
118
+ signing_key:
119
+ specification_version: 4
120
+ summary: Rumale::ModelSelection provides model validation techniques with Rumale interface.
121
+ test_files: []