rumale-model_selection 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: 28fa6e0e0c5832366f7130aa8185283b9ae316220b4cebec64774acab7f5cfd6
+   data.tar.gz: c85aa131bbcc278f99f2aa91dc6d529c291a2274dbf46c96b468191fefc8015d
+ SHA512:
+   metadata.gz: ad3df526ec16120843536a9f199ba3f0c5856f40ac5a9de545b34b43bb60a798ad17b1c75588e7cfc59caa34bb886c7f5f180aacb9f24e338d70ebd389bd58aa
+   data.tar.gz: 8214f4cec33f056f49d2a479d23f16bc264b92c1e457af695c20d6b5b5233f78c2ab735d6ed584c00ce5333ec9779a8966d6bfd5c96128b7984cda6269b88732
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
+ Copyright (c) 2022 Atsushi Tatsuma
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,34 @@
+ # Rumale::ModelSelection
+
+ [![Gem Version](https://badge.fury.io/rb/rumale-model_selection.svg)](https://badge.fury.io/rb/rumale-model_selection)
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-model_selection/LICENSE.txt)
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
+
+ Rumale is a machine learning library in Ruby.
+ Rumale::ModelSelection provides model validation techniques,
+ such as k-fold cross-validation, time series cross-validation, and grid search,
+ with the Rumale interface.
+
+ ## Installation
+
+ Add this line to your application's Gemfile:
+
+ ```ruby
+ gem 'rumale-model_selection'
+ ```
+
+ And then execute:
+
+     $ bundle install
+
+ Or install it yourself as:
+
+     $ gem install rumale-model_selection
+
+ ## Documentation
+
+ - [Rumale API Documentation - ModelSelection](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection.html)
+
+ ## License
+
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
data/lib/rumale/model_selection/cross_validation.rb ADDED
@@ -0,0 +1,107 @@
+ # frozen_string_literal: true
+
+ require 'rumale/evaluation_measure/log_loss'
+
+ module Rumale
+   # This module consists of the classes for model validation techniques.
+   module ModelSelection
+     # CrossValidation is a class that evaluates a given classifier with the cross-validation method.
+     #
+     # @example
+     #   require 'rumale/linear_model'
+     #   require 'rumale/model_selection/stratified_k_fold'
+     #   require 'rumale/model_selection/cross_validation'
+     #
+     #   svc = Rumale::LinearModel::SVC.new
+     #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
+     #   cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
+     #   report = cv.perform(samples, labels)
+     #   mean_test_score = report[:test_score].inject(:+) / kf.n_splits
+     #
+     class CrossValidation
+       # Return the classifier whose performance is evaluated.
+       # @return [Classifier]
+       attr_reader :estimator
+
+       # Return the splitter that divides the dataset.
+       # @return [Splitter]
+       attr_reader :splitter
+
+       # Return the evaluator that calculates the score.
+       # @return [Evaluator]
+       attr_reader :evaluator
+
+       # Return the flag indicating whether to calculate the score of the training dataset.
+       # @return [Boolean]
+       attr_reader :return_train_score
+
+       # Create a new evaluator with the cross-validation method.
+       #
+       # @param estimator [Classifier] The classifier whose performance is evaluated.
+       # @param splitter [Splitter] The splitter that divides the dataset into training and testing datasets.
+       # @param evaluator [Evaluator] The evaluator that calculates the score of estimator results.
+       # @param return_train_score [Boolean] The flag indicating whether to calculate the score of the training dataset.
+       def initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false)
+         @estimator = estimator
+         @splitter = splitter
+         @evaluator = evaluator
+         @return_train_score = return_train_score
+       end
+
+       # Perform the evaluation of the given classifier with the cross-validation method.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #   The dataset to be used to evaluate the estimator.
+       # @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs])
+       #   The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor.
+       # @return [Hash] The report summarizing the results of cross-validation.
+       #   * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split.
+       #   * :test_score (Array<Float>) The scores on the testing dataset for each split.
+       #   * :train_score (Array<Float>) The scores on the training dataset for each split. This entry is nil if
+       #     return_train_score is false.
+       def perform(x, y)
+         # Initialize the report of cross validation.
+         report = { test_score: [], train_score: nil, fit_time: [] }
+         report[:train_score] = [] if @return_train_score
+         # Evaluate the estimator on each split.
+         @splitter.split(x, y).each do |train_ids, test_ids|
+           # Split dataset into training and testing dataset.
+           feature_ids = !kernel_machine? || train_ids
+           train_x = x[train_ids, feature_ids]
+           train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true]
+           test_x = x[test_ids, feature_ids]
+           test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true]
+           # Fit the estimator.
+           start_time = Time.now.to_i
+           @estimator.fit(train_x, train_y)
+           # Calculate scores and prepare the report.
+           report[:fit_time].push(Time.now.to_i - start_time)
+           if @evaluator.nil?
+             report[:test_score].push(@estimator.score(test_x, test_y))
+             report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
+           elsif log_loss?
+             report[:test_score].push(@evaluator.score(test_y, @estimator.predict_proba(test_x)))
+             if @return_train_score
+               report[:train_score].push(@evaluator.score(train_y,
+                                                          @estimator.predict_proba(train_x)))
+             end
+           else
+             report[:test_score].push(@evaluator.score(test_y, @estimator.predict(test_x)))
+             report[:train_score].push(@evaluator.score(train_y, @estimator.predict(train_x))) if @return_train_score
+           end
+         end
+         report
+       end
+
+       private
+
+       def kernel_machine?
+         @estimator.class.name.include?('Rumale::KernelMachine')
+       end
+
+       def log_loss?
+         @evaluator.is_a?(::Rumale::EvaluationMeasure::LogLoss)
+       end
+     end
+   end
+ end
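Two details of `perform` are worth noting: `:fit_time` is measured with `Time.now.to_i`, so it is recorded at whole-second resolution, and for estimators under `Rumale::KernelMachine` the expression `!kernel_machine? || train_ids` restricts the columns of a precomputed kernel matrix to the training indices (for all other estimators it evaluates to `true`, selecting every column). Below is a minimal runnable sketch of consuming the report, assuming synthetic data and the separate rumale-linear_model gem (neither ships with this package):

```ruby
require 'numo/narray'
require 'rumale/linear_model'
require 'rumale/model_selection/k_fold'
require 'rumale/model_selection/cross_validation'

# Synthetic two-class dataset: 30 samples, 2 features (illustrative only).
x = Numo::DFloat.new(30, 2).rand
y = Numo::Int32[*([0] * 15 + [1] * 15)]

svc = Rumale::LinearModel::SVC.new(random_seed: 1)
kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf, return_train_score: true)

# The report is a plain Hash with one Array entry per fold.
report = cv.perform(x, y)
mean_test = report[:test_score].sum / kf.n_splits
mean_train = report[:train_score].sum / kf.n_splits
puts format('mean test score: %.3f, mean train score: %.3f', mean_test, mean_train)
```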
data/lib/rumale/model_selection/function.rb ADDED
@@ -0,0 +1,47 @@
+ # frozen_string_literal: true
+
+ require 'numo/narray'
+
+ require 'rumale/model_selection/shuffle_split'
+ require 'rumale/model_selection/stratified_shuffle_split'
+
+ module Rumale
+   # This module consists of the classes for model validation techniques.
+   module ModelSelection
+     module_function
+
+     # Randomly split a dataset into train and test data.
+     #
+     # @example
+     #   require 'rumale/model_selection/function'
+     #
+     #   x_train, x_test, y_train, y_test = Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)
+     #
+     # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to be used to generate data indices.
+     # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used to generate data indices for stratified random permutation.
+     #   If stratify = false, this parameter is ignored.
+     # @param test_size [Float] The ratio of the number of samples for test data.
+     # @param train_size [Float] The ratio of the number of samples for train data.
+     #   If nil is given, it is set to 1 - test_size.
+     # @param stratify [Boolean] The flag indicating whether to perform a stratified split.
+     # @param random_seed [Integer] The seed value used to initialize the random generator.
+     # @return [Array<Numo::NArray>] The set of training and testing data.
+     def train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil)
+       splitter = if stratify
+                    ::Rumale::ModelSelection::StratifiedShuffleSplit.new(
+                      n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
+                    )
+                  else
+                    ::Rumale::ModelSelection::ShuffleSplit.new(
+                      n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
+                    )
+                  end
+       train_ids, test_ids = splitter.split(x, y).first
+       x_train = x[train_ids, true].dup
+       y_train = y[train_ids].dup
+       x_test = x[test_ids, true].dup
+       y_test = y[test_ids].dup
+       [x_train, x_test, y_train, y_test]
+     end
+   end
+ end
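A minimal runnable sketch of `train_test_split` on synthetic data (the dataset and seed are illustrative, not from the package). With `stratify: true` the split is delegated to `StratifiedShuffleSplit`, so each class keeps roughly the same proportion in both partitions:

```ruby
require 'numo/narray'
require 'rumale/model_selection'

# Synthetic dataset: 20 samples, 2 features, 2 balanced classes.
x = Numo::DFloat.new(20, 2).rand
y = Numo::Int32[*([0] * 10 + [1] * 10)]

x_train, x_test, y_train, y_test =
  Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)

puts "train: #{x_train.shape[0]} samples, test: #{x_test.shape[0]} samples"
# Each class contributes ceil(0.2 * 10) = 2 samples to the test split.
```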
data/lib/rumale/model_selection/grid_search_cv.rb ADDED
@@ -0,0 +1,213 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/estimator'
+ require 'rumale/model_selection/cross_validation'
+
+ module Rumale
+   module ModelSelection
+     # GridSearchCV is a class that performs hyperparameter optimization with the grid search method.
+     #
+     # @example
+     #   require 'rumale/ensemble'
+     #   require 'rumale/model_selection/stratified_k_fold'
+     #   require 'rumale/model_selection/grid_search_cv'
+     #
+     #   rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
+     #   pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
+     #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
+     #   gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
+     #   gs.fit(samples, labels)
+     #   p gs.cv_results
+     #   p gs.best_params
+     #
+     # @example
+     #   rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
+     #   svc = Rumale::LinearModel::SVC.new(random_seed: 1)
+     #   pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
+     #   pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
+     #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
+     #   gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
+     #   gs.fit(samples, labels)
+     #   p gs.cv_results
+     #   p gs.best_params
+     #
+     class GridSearchCV < ::Rumale::Base::Estimator
+       # Return the result of cross validation for each parameter.
+       # @return [Hash]
+       attr_reader :cv_results
+
+       # Return the score of the estimator learned with the best parameter.
+       # @return [Float]
+       attr_reader :best_score
+
+       # Return the best parameter set.
+       # @return [Hash]
+       attr_reader :best_params
+
+       # Return the index of the best parameter.
+       # @return [Integer]
+       attr_reader :best_index
+
+       # Return the estimator learned with the best parameter.
+       # @return [Estimator]
+       attr_reader :best_estimator
+
+       # Create a new grid search method.
+       #
+       # @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with the grid search method.
+       # @param param_grid [Array<Hash>] The parameter sets are represented as an array of hashes that
+       #   consist of parameter names as keys and arrays of parameter values as values.
+       # @param splitter [Splitter] The splitter that divides the dataset into training and testing datasets on cross validation.
+       # @param evaluator [Evaluator] The evaluator that calculates the score of estimator results on cross validation.
+       #   If nil is given, the score method of the estimator is used for evaluation.
+       # @param greater_is_better [Boolean] The flag that indicates whether a larger
+       #   evaluation score means a better estimator.
+       def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
+         super()
+         @params = {
+           param_grid: valid_param_grid(param_grid),
+           estimator: Marshal.load(Marshal.dump(estimator)),
+           splitter: Marshal.load(Marshal.dump(splitter)),
+           evaluator: Marshal.load(Marshal.dump(evaluator)),
+           greater_is_better: greater_is_better
+         }
+       end
+
+       # Fit the model with given training data and all sets of parameters.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
+       # @return [GridSearchCV] The learned estimator with grid search.
+       def fit(x, y)
+         init_attrs
+
+         param_combinations.each do |prm_set|
+           prm_set.each do |prms|
+             report = perform_cross_validation(x, y, prms)
+             store_cv_result(prms, report)
+           end
+         end
+
+         find_best_params
+
+         @best_estimator = configurated_estimator(@best_params)
+         @best_estimator.fit(x, y)
+         self
+       end
+
+       # Call the decision_function method of the learned estimator with the best parameter.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+       # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
+       def decision_function(x)
+         @best_estimator.decision_function(x)
+       end
+
+       # Call the predict method of the learned estimator with the best parameter.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
+       # @return [Numo::NArray] Predicted results.
+       def predict(x)
+         @best_estimator.predict(x)
+       end
+
+       # Call the predict_log_proba method of the learned estimator with the best parameter.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
+       def predict_log_proba(x)
+         @best_estimator.predict_log_proba(x)
+       end
+
+       # Call the predict_proba method of the learned estimator with the best parameter.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+       def predict_proba(x)
+         @best_estimator.predict_proba(x)
+       end
+
+       # Call the score method of the learned estimator with the best parameter.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
+       # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
+       # @return [Float] The score of the estimator.
+       def score(x, y)
+         @best_estimator.score(x, y)
+       end
+
+       private
+
+       def valid_param_grid(grid)
+         raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)
+
+         grid = [grid] if grid.is_a?(Hash)
+         grid.each do |h|
+           raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
+           raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
+         end
+         grid
+       end
+
+       def param_combinations
+         @param_combinations ||= @params[:param_grid].map do |prm|
+           x = prm.sort.to_h.map { |k, v| [k].product(v) }
+           x[0].product(*x[1...x.size]).map(&:to_h)
+         end
+       end
+
+       def perform_cross_validation(x, y, prms)
+         est = configurated_estimator(prms)
+         cv = ::Rumale::ModelSelection::CrossValidation.new(estimator: est, splitter: @params[:splitter],
+                                                            evaluator: @params[:evaluator], return_train_score: true)
+         cv.perform(x, y)
+       end
+
+       def configurated_estimator(prms)
+         estimator = Marshal.load(Marshal.dump(@params[:estimator]))
+         if pipeline?
+           prms.each do |k, v|
+             est_name, prm_name = k.to_s.split('__')
+             estimator.steps[est_name.to_sym].params[prm_name.to_sym] = v
+           end
+         else
+           prms.each { |k, v| estimator.params[k] = v }
+         end
+         estimator
+       end
+
+       def init_attrs
+         @cv_results = %i[mean_test_score std_test_score
+                          mean_train_score std_train_score
+                          mean_fit_time std_fit_time params].to_h { |v| [v, []] }
+         @best_score = nil
+         @best_params = nil
+         @best_index = nil
+         @best_estimator = nil
+       end
+
+       def store_cv_result(prms, report)
+         test_scores = Numo::DFloat[*report[:test_score]]
+         train_scores = Numo::DFloat[*report[:train_score]]
+         fit_times = Numo::DFloat[*report[:fit_time]]
+         @cv_results[:mean_test_score].push(test_scores.mean)
+         @cv_results[:std_test_score].push(test_scores.stddev)
+         @cv_results[:mean_train_score].push(train_scores.mean)
+         @cv_results[:std_train_score].push(train_scores.stddev)
+         @cv_results[:mean_fit_time].push(fit_times.mean)
+         @cv_results[:std_fit_time].push(fit_times.stddev)
+         @cv_results[:params].push(prms)
+       end
+
+       def find_best_params
+         @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
+         @best_index = @cv_results[:mean_test_score].index(@best_score)
+         @best_params = @cv_results[:params][@best_index]
+       end
+
+       def pipeline?
+         @params[:estimator].class.name.include?('Rumale::Pipeline')
+       end
+     end
+   end
+ end
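The densest part of the class is `param_combinations`, which sorts each grid hash by key, pairs every key with each of its candidate values, and takes the Cartesian product. A standalone Ruby sketch of that same expansion, with hypothetical parameter names:

```ruby
# Mirrors GridSearchCV#param_combinations for a single grid hash.
grid = { max_depth: [3, 5], n_estimators: [5, 10] }

# Pair each key with each of its candidate values.
pairs = grid.sort.to_h.map { |k, v| [k].product(v) }
# pairs => [[[:max_depth, 3], [:max_depth, 5]], [[:n_estimators, 5], [:n_estimators, 10]]]

# Cartesian product across keys, then convert each combination back to a Hash.
combos = pairs[0].product(*pairs[1...pairs.size]).map(&:to_h)
# combos => [{ max_depth: 3, n_estimators: 5 }, { max_depth: 3, n_estimators: 10 },
#            { max_depth: 5, n_estimators: 5 }, { max_depth: 5, n_estimators: 10 }]
```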
data/lib/rumale/model_selection/group_k_fold.rb ADDED
@@ -0,0 +1,93 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+ require 'rumale/preprocessing/label_encoder'
+
+ module Rumale
+   module ModelSelection
+     # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
+     # Data points belonging to the same group are not split into different folds.
+     # The number of groups should be greater than or equal to the number of splits.
+     #
+     # @example
+     #   require 'rumale/model_selection/group_k_fold'
+     #
+     #   cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
+     #   x = Numo::DFloat.new(8, 2).rand
+     #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
+     #   cv.split(x, nil, groups).each do |train_ids, test_ids|
+     #     puts '---'
+     #     pp train_ids
+     #     pp test_ids
+     #   end
+     #
+     #   # ---
+     #   # [0, 1, 2, 3, 4]
+     #   # [5, 6, 7]
+     #   # ---
+     #   # [3, 4, 5, 6, 7]
+     #   # [0, 1, 2]
+     #   # ---
+     #   # [0, 1, 2, 5, 6, 7]
+     #   # [3, 4]
+     #
+     class GroupKFold
+       include ::Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Create a new data splitter for grouped K-fold cross validation.
+       #
+       # @param n_splits [Integer] The number of folds.
+       def initialize(n_splits: 5)
+         @n_splits = n_splits
+       end
+
+       # Generate data indices for grouped K-fold cross validation.
+       #
+       # @overload split(x, y, groups) -> Array
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #     The dataset to be used to generate data indices for grouped K-fold cross validation.
+       #   @param y [Numo::Int32] (shape: [n_samples])
+       #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+       #   @param groups [Numo::Int32] (shape: [n_samples])
+       #     The group labels to be used to generate data indices for grouped K-fold cross validation.
+       #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(x, _y, groups)
+         encoder = ::Rumale::Preprocessing::LabelEncoder.new
+         groups = encoder.fit_transform(groups)
+         n_groups = encoder.classes.size
+
+         if n_groups < @n_splits
+           raise ArgumentError,
+                 'The number of groups should be greater than or equal to the number of splits.'
+         end
+
+         n_samples_per_group = groups.bincount
+         group_ids = n_samples_per_group.sort_index.reverse
+         n_samples_per_group = n_samples_per_group[group_ids]
+
+         n_samples_per_fold = Numo::Int32.zeros(@n_splits)
+         group_to_fold = Numo::Int32.zeros(n_groups)
+
+         n_samples_per_group.each_with_index do |weight, id|
+           min_sample_fold_id = n_samples_per_fold.min_index
+           n_samples_per_fold[min_sample_fold_id] += weight
+           group_to_fold[group_ids[id]] = min_sample_fold_id
+         end
+
+         n_samples = x.shape[0]
+         sample_ids = Array(0...n_samples)
+         fold_ids = group_to_fold[groups]
+
+         Array.new(@n_splits) do |fid|
+           test_ids = fold_ids.eq(fid).where.to_a
+           train_ids = sample_ids - test_ids
+           [train_ids, test_ids]
+         end
+       end
+     end
+   end
+ end
data/lib/rumale/model_selection/group_shuffle_split.rb ADDED
@@ -0,0 +1,108 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   module ModelSelection
+     # GroupShuffleSplit is a class that generates the set of data indices
+     # for random permutation cross-validation by randomly selecting group labels.
+     #
+     # @example
+     #   require 'rumale/model_selection/group_shuffle_split'
+     #
+     #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
+     #   x = Numo::DFloat.new(8, 2).rand
+     #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
+     #   cv.split(x, nil, groups).each do |train_ids, test_ids|
+     #     puts '---'
+     #     pp train_ids
+     #     pp test_ids
+     #   end
+     #
+     #   # ---
+     #   # [0, 1, 2, 5, 6, 7]
+     #   # [3, 4]
+     #   # ---
+     #   # [3, 4, 5, 6, 7]
+     #   # [0, 1, 2]
+     #
+     class GroupShuffleSplit
+       include Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the random generator for shuffling the dataset.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new data splitter for random permutation cross validation with given group labels.
+       #
+       # @param n_splits [Integer] The number of folds.
+       # @param test_size [Float] The ratio of the number of groups for test data.
+       # @param train_size [Float/Nil] The ratio of the number of groups for train data.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
+         @n_splits = n_splits
+         @test_size = test_size
+         @train_size = train_size
+         @random_seed = random_seed
+         @random_seed ||= srand
+         @rng = Random.new(@random_seed)
+       end
+
+       # Generate train and test data indices by randomly selecting group labels.
+       #
+       # @overload split(x, y, groups) -> Array
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #     The dataset to be used to generate data indices for random permutation cross validation.
+       #   @param y [Numo::Int32] (shape: [n_samples])
+       #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+       #   @param groups [Numo::Int32] (shape: [n_samples])
+       #     The group labels to be used to generate data indices for random permutation cross validation.
+       #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(_x, _y, groups)
+         classes = groups.to_a.uniq.sort
+         n_groups = classes.size
+         n_test_groups = (@test_size * n_groups).ceil.to_i
+         n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i
+
+         unless n_test_groups.between?(1, n_groups)
+           raise RangeError,
+                 'The number of groups in test split must be not less than 1 and not more than the number of groups.'
+         end
+         unless n_train_groups.between?(1, n_groups)
+           raise RangeError,
+                 'The number of groups in train split must be not less than 1 and not more than the number of groups.'
+         end
+         if (n_test_groups + n_train_groups) > n_groups
+           raise RangeError,
+                 'The total number of groups in test split and train split must be not more than the number of groups.'
+         end
+
+         sub_rng = @rng.dup
+
+         Array.new(@n_splits) do
+           test_group_ids = classes.sample(n_test_groups, random: sub_rng)
+           train_group_ids = if @train_size.nil?
+                               classes - test_group_ids
+                             else
+                               (classes - test_group_ids).sample(n_train_groups, random: sub_rng)
+                             end
+           test_ids = in1d(groups, test_group_ids).where.to_a
+           train_ids = in1d(groups, train_group_ids).where.to_a
+           [train_ids, test_ids]
+         end
+       end
+
+       private
+
+       def in1d(a, b)
+         res = Numo::Bit.zeros(a.shape[0])
+         b.each { |v| res |= a.eq(v) }
+         res
+       end
+     end
+   end
+ end
data/lib/rumale/model_selection/k_fold.rb ADDED
@@ -0,0 +1,78 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   # This module consists of the classes for model validation techniques.
+   module ModelSelection
+     # KFold is a class that generates the set of data indices for K-fold cross-validation.
+     #
+     # @example
+     #   require 'rumale/model_selection/k_fold'
+     #
+     #   kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
+     #   kf.split(samples, labels).each do |train_ids, test_ids|
+     #     train_samples = samples[train_ids, true]
+     #     test_samples = samples[test_ids, true]
+     #     ...
+     #   end
+     #
+     class KFold
+       include ::Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the flag indicating whether to shuffle the dataset.
+       # @return [Boolean]
+       attr_reader :shuffle
+
+       # Return the random generator for shuffling the dataset.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new data splitter for K-fold cross validation.
+       #
+       # @param n_splits [Integer] The number of folds.
+       # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_splits: 3, shuffle: false, random_seed: nil)
+         @n_splits = n_splits
+         @shuffle = shuffle
+         @random_seed = random_seed
+         @random_seed ||= srand
+         @rng = Random.new(@random_seed)
+       end
+
+       # Generate data indices for K-fold cross validation.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #   The dataset to be used to generate data indices for K-fold cross validation.
+       # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(x, _y = nil)
+         # Initialize and check some variables.
+         n_samples, = x.shape
+         unless @n_splits.between?(2, n_samples)
+           raise ArgumentError,
+                 'The value of n_splits must be not less than 2 and not more than the number of samples.'
+         end
+         sub_rng = @rng.dup
+         # Splits dataset ids to each fold.
+         dataset_ids = Array(0...n_samples)
+         dataset_ids.shuffle!(random: sub_rng) if @shuffle
+         fold_sets = Array.new(@n_splits) do |n|
+           n_fold_samples = n_samples / @n_splits
+           n_fold_samples += 1 if n < n_samples % @n_splits
+           dataset_ids.shift(n_fold_samples)
+         end
+         # Returns array consisting of the training and testing ids for each fold.
+         Array.new(@n_splits) do |n|
+           train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
+           test_ids = fold_sets[n]
+           [train_ids, test_ids]
+         end
+       end
+     end
+   end
+ end
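The `@example` above elides its loop body with `...`; a complete, runnable variant on synthetic data (illustrative only) might look like this:

```ruby
require 'numo/narray'
require 'rumale/model_selection/k_fold'

# Synthetic dataset: 6 samples, 2 features.
samples = Numo::DFloat.new(6, 2).rand

kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples).each_with_index do |(train_ids, test_ids), fold|
  # Each fold holds out 2 of the 6 sample ids for testing.
  puts "fold #{fold}: train=#{train_ids.inspect} test=#{test_ids.inspect}"
end
```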
data/lib/rumale/model_selection/shuffle_split.rb ADDED
@@ -0,0 +1,86 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   module ModelSelection
+     # ShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
+     #
+     # @example
+     #   require 'rumale/model_selection/shuffle_split'
+     #
+     #   ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
+     #   ss.split(samples, labels).each do |train_ids, test_ids|
+     #     train_samples = samples[train_ids, true]
+     #     test_samples = samples[test_ids, true]
+     #     ...
+     #   end
+     #
+     class ShuffleSplit
+       include ::Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the random generator for shuffling the dataset.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new data splitter for random permutation cross validation.
+       #
+       # @param n_splits [Integer] The number of folds.
+       # @param test_size [Float] The ratio of the number of samples for test data.
+       # @param train_size [Float] The ratio of the number of samples for train data.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
+         @n_splits = n_splits
+         @test_size = test_size
+         @train_size = train_size
+         @random_seed = random_seed
+         @random_seed ||= srand
+         @rng = Random.new(@random_seed)
+       end
+
+       # Generate data indices for random permutation cross validation.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #   The dataset to be used to generate data indices for random permutation cross validation.
+       # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(x, _y = nil)
+         # Initialize and check some variables.
+         n_samples = x.shape[0]
+         n_test_samples = (@test_size * n_samples).ceil.to_i
+         n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
+         unless @n_splits.between?(1, n_samples)
+           raise ArgumentError,
+                 'The value of n_splits must be not less than 1 and not more than the number of samples.'
+         end
+         unless n_test_samples.between?(1, n_samples)
+           raise RangeError,
+                 'The number of samples in test split must be not less than 1 and not more than the number of samples.'
+         end
+         unless n_train_samples.between?(1, n_samples)
+           raise RangeError,
+                 'The number of samples in train split must be not less than 1 and not more than the number of samples.'
+         end
+         if (n_test_samples + n_train_samples) > n_samples
+           raise RangeError,
+                 'The total number of samples in test split and train split must be not more than the number of samples.'
+         end
+         sub_rng = @rng.dup
+         # Returns array consisting of the training and testing ids for each fold.
+         dataset_ids = Array(0...n_samples)
+         Array.new(@n_splits) do
+           test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
+           train_ids = if @train_size.nil?
+                         dataset_ids - test_ids
+                       else
+                         (dataset_ids - test_ids).sample(n_train_samples, random: sub_rng)
+                       end
+           [train_ids, test_ids]
+         end
+       end
+     end
+   end
+ end
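A runnable completion of the elided `@example`, again with synthetic data. With `train_size` left nil, each fold's training set is simply the complement of the sampled test ids:

```ruby
require 'numo/narray'
require 'rumale/model_selection/shuffle_split'

# Synthetic dataset: 10 samples, 2 features.
samples = Numo::DFloat.new(10, 2).rand

ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
ss.split(samples).each do |train_ids, test_ids|
  # ceil(0.2 * 10) = 2 test ids per fold; the remaining 8 ids form the training set.
  puts "train=#{train_ids.size} test=#{test_ids.size}"
end
```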
data/lib/rumale/model_selection/stratified_k_fold.rb ADDED
@@ -0,0 +1,95 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   module ModelSelection
+     # StratifiedKFold is a class that generates the set of data indices for K-fold cross-validation.
+     # The proportion of the number of samples in each class will be almost equal for each fold.
+     #
+     # @example
+     #   require 'rumale/model_selection/stratified_k_fold'
+     #
+     #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
+     #   kf.split(samples, labels).each do |train_ids, test_ids|
+     #     train_samples = samples[train_ids, true]
+     #     test_samples = samples[test_ids, true]
+     #     ...
+     #   end
+     #
+     class StratifiedKFold
+       include ::Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the flag indicating whether to shuffle the dataset.
+       # @return [Boolean]
+       attr_reader :shuffle
+
+       # Return the random generator for shuffling the dataset.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new data splitter for stratified K-fold cross validation.
+       #
+       # @param n_splits [Integer] The number of folds.
+       # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_splits: 3, shuffle: false, random_seed: nil)
+         @n_splits = n_splits
+         @shuffle = shuffle
+         @random_seed = random_seed
+         @random_seed ||= srand
+         @rng = Random.new(@random_seed)
+       end
+
+       # Generate data indices for stratified K-fold cross validation.
+       #
+       # @overload split(x, y) -> Array
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #     The dataset to be used to generate data indices for stratified K-fold cross validation.
+       #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+       #   @param y [Numo::Int32] (shape: [n_samples])
+       #     The labels to be used to generate data indices for stratified K-fold cross validation.
+       #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(_x, y)
+         # Check the number of samples in each class.
+         unless valid_n_splits?(y)
+           raise ArgumentError,
+                 'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
+         end
+         # Splits dataset ids of each class to each fold.
+         sub_rng = @rng.dup
+         fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label, sub_rng) }
+         # Returns array consisting of the training and testing ids for each fold.
+         Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
+       end
+
+       private
+
+       def valid_n_splits?(y)
+         y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(2, n_samples) }
+       end
+
+       def fold_sets(y, label, sub_rng)
+         sample_ids = y.eq(label).where.to_a
+         sample_ids.shuffle!(random: sub_rng) if @shuffle
+         n_samples = sample_ids.size
+         Array.new(@n_splits) do |n|
+           n_fold_samples = n_samples / @n_splits
+           n_fold_samples += 1 if n < n_samples % @n_splits
+           sample_ids.shift(n_fold_samples)
+         end
+       end
+
+       def train_test_sets(fold_sets_each_class, fold_id)
+         train_test_sets_each_class = fold_sets_each_class.map do |folds|
+           folds.partition.with_index { |_, id| id != fold_id }.map(&:flatten)
+         end
+         train_test_sets_each_class.transpose.map(&:flatten)
+       end
+     end
+   end
+ end
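A runnable completion of this class's elided `@example`, assuming a small balanced label vector; each test fold ends up with roughly the same number of samples per class:

```ruby
require 'numo/narray'
require 'rumale/model_selection/stratified_k_fold'

# Synthetic dataset: 12 samples, 2 features, 6 samples per class.
samples = Numo::DFloat.new(12, 2).rand
labels = Numo::Int32[*([0] * 6 + [1] * 6)]

skf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
skf.split(samples, labels).each do |train_ids, test_ids|
  # Each test fold gets 2 samples from class 0 and 2 from class 1.
  test_labels = labels[test_ids]
  puts "test fold: #{test_ids.size} samples, #{test_labels.eq(0).count} of class 0"
end
```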
data/lib/rumale/model_selection/stratified_shuffle_split.rb ADDED
@@ -0,0 +1,115 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   module ModelSelection
+     # StratifiedShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
+     # The proportion of the number of samples in each class will be almost equal for each fold.
+     #
+     # @example
+     #   require 'rumale/model_selection/stratified_shuffle_split'
+     #
+     #   ss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
+     #   ss.split(samples, labels).each do |train_ids, test_ids|
+     #     train_samples = samples[train_ids, true]
+     #     test_samples = samples[test_ids, true]
+     #     ...
+     #   end
+     #
+     class StratifiedShuffleSplit
+       include ::Rumale::Base::Splitter
+
+       # Return the number of folds.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the random generator for shuffling the dataset.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new data splitter for random permutation cross validation.
+       #
+       # @param n_splits [Integer] The number of folds.
+       # @param test_size [Float] The ratio of the number of samples for test data.
+       # @param train_size [Float] The ratio of the number of samples for train data.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
+         @n_splits = n_splits
+         @test_size = test_size
+         @train_size = train_size
+         @random_seed = random_seed
+         @random_seed ||= srand
+         @rng = Random.new(@random_seed)
+       end
+
+       # Generate data indices for stratified random permutation cross validation.
+       #
+       # @overload split(x, y) -> Array
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #     The dataset to be used to generate data indices for stratified random permutation cross validation.
+       #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+       #   @param y [Numo::Int32] (shape: [n_samples])
+       #     The labels to be used to generate data indices for stratified random permutation cross validation.
+       #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(_x, y)
+         # Initialize and check some variables.
+         train_sz = @train_size.nil? ? 1.0 - @test_size : @train_size
+         sub_rng = @rng.dup
+         # Check the number of samples in each class.
+         unless valid_n_splits?(y)
+           raise ArgumentError,
+                 'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
+         end
+         # rubocop:disable Layout/LineLength
+         unless enough_data_size_each_class?(y, @test_size, 'test')
+           raise RangeError,
+                 'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
+         end
+         unless enough_data_size_each_class?(y, train_sz, 'train')
+           raise RangeError,
+                 'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
+         end
+         unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
+           raise RangeError,
+                 'The total number of samples in test split and train split must be not more than the number of samples in each class.'
+         end
+         # rubocop:enable Layout/LineLength
+         # Returns array consisting of the training and testing ids for each fold.
+         sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
+         Array.new(@n_splits) do
+           train_ids = []
+           test_ids = []
+           sample_ids_each_class.each do |sample_ids|
+             n_samples = sample_ids.size
+             n_test_samples = (@test_size * n_samples).ceil.to_i
+             test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
+             train_ids += if @train_size.nil?
+                            sample_ids - test_ids
+                          else
+                            n_train_samples = (train_sz * n_samples).floor.to_i
+                            (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
+                          end
+           end
+           [train_ids, test_ids]
+         end
+       end
+
+       private
+
+       def valid_n_splits?(y)
+         y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
+       end
+
+       def enough_data_size_each_class?(y, data_size, data_type)
+         y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
+           if data_type == 'test'
+             (data_size * n_samples).ceil.to_i.between?(1, n_samples)
+           else
+             (data_size * n_samples).floor.to_i.between?(1, n_samples)
+           end
+         end
+       end
+     end
+   end
+ end
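And a runnable completion for this class (synthetic data again); per class, the test count is `(test_size * n_samples_in_class).ceil`:

```ruby
require 'numo/narray'
require 'rumale/model_selection/stratified_shuffle_split'

# Synthetic dataset: 20 samples, 2 features, 10 samples per class.
samples = Numo::DFloat.new(20, 2).rand
labels = Numo::Int32[*([0] * 10 + [1] * 10)]

sss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
sss.split(samples, labels).each do |train_ids, test_ids|
  # ceil(0.2 * 10) = 2 test samples per class, 4 in total per fold.
  puts "train=#{train_ids.size} test=#{test_ids.size}"
end
```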
data/lib/rumale/model_selection/time_series_split.rb ADDED
@@ -0,0 +1,89 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/splitter'
+
+ module Rumale
+   module ModelSelection
+     # TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
+     # It is assumed that the given dataset is already ordered by time information.
+     #
+     # @example
+     #   require 'rumale/model_selection/time_series_split'
+     #
+     #   cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
+     #   x = Numo::DFloat.new(6, 2).rand
+     #   cv.split(x, nil).each do |train_ids, test_ids|
+     #     puts '---'
+     #     pp train_ids
+     #     pp test_ids
+     #   end
+     #
+     #   # ---
+     #   # [0]
+     #   # [1]
+     #   # ---
+     #   # [0, 1]
+     #   # [2]
+     #   # ---
+     #   # [0, 1, 2]
+     #   # [3]
+     #   # ---
+     #   # [0, 1, 2, 3]
+     #   # [4]
+     #   # ---
+     #   # [0, 1, 2, 3, 4]
+     #   # [5]
+     #
+     class TimeSeriesSplit
+       include ::Rumale::Base::Splitter
+
+       # Return the number of splits.
+       # @return [Integer]
+       attr_reader :n_splits
+
+       # Return the maximum number of training samples in a split.
+       # @return [Integer/Nil]
+       attr_reader :max_train_size
+
+       # Create a new data splitter for time series cross-validation.
+       #
+       # @param n_splits [Integer] The number of splits.
+       # @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
+       def initialize(n_splits: 5, max_train_size: nil)
+         @n_splits = n_splits
+         @max_train_size = max_train_size
+       end
+
+       # Generate data indices for time series cross-validation.
+       #
+       # @overload split(x, y) -> Array
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #     The dataset to be used to generate data indices for time series cross-validation.
+       #     It is expected that the data will be ordered by time information.
+       #   @param y [Numo::Int32] (shape: [n_samples])
+       #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+       #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+       def split(x, _y)
+         n_samples = x.shape[0]
+         unless (@n_splits + 1).between?(2, n_samples)
+           raise ArgumentError,
+                 'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
+         end
+
+         test_size = n_samples / (@n_splits + 1)
+         offset = test_size + n_samples % (@n_splits + 1)
+
+         Array.new(@n_splits) do |n|
+           start = offset * (n + 1)
+           train_ids = if !@max_train_size.nil? && @max_train_size < test_size
+                         Array((start - @max_train_size)...start)
+                       else
+                         Array(0...start)
+                       end
+           test_ids = Array(start...(start + test_size))
+           [train_ids, test_ids]
+         end
+       end
+     end
+   end
+ end
data/lib/rumale/model_selection/version.rb ADDED
@@ -0,0 +1,8 @@
+ # frozen_string_literal: true
+
+ module Rumale
+   module ModelSelection
+     # @!visibility private
+     VERSION = '0.24.0'
+   end
+ end
data/lib/rumale/model_selection.rb ADDED
@@ -0,0 +1,16 @@
+ # frozen_string_literal: true
+
+ require 'numo/narray'
+
+ require_relative 'model_selection/version'
+
+ require_relative 'model_selection/cross_validation'
+ require_relative 'model_selection/function'
+ require_relative 'model_selection/grid_search_cv'
+ require_relative 'model_selection/group_k_fold'
+ require_relative 'model_selection/group_shuffle_split'
+ require_relative 'model_selection/k_fold'
+ require_relative 'model_selection/shuffle_split'
+ require_relative 'model_selection/stratified_k_fold'
+ require_relative 'model_selection/stratified_shuffle_split'
+ require_relative 'model_selection/time_series_split'
metadata ADDED
@@ -0,0 +1,121 @@
+ --- !ruby/object:Gem::Specification
+ name: rumale-model_selection
+ version: !ruby/object:Gem::Version
+   version: 0.24.0
+ platform: ruby
+ authors:
+ - yoshoku
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2022-12-31 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: numo-narray
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: 0.9.1
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: 0.9.1
+ - !ruby/object:Gem::Dependency
+   name: rumale-core
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+ - !ruby/object:Gem::Dependency
+   name: rumale-evaluation_measure
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+ - !ruby/object:Gem::Dependency
+   name: rumale-preprocessing
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+ description: |
+   Rumale::ModelSelection provides model validation techniques,
+   such as k-fold cross-validation, time series cross-validation, and grid search,
+   with the Rumale interface.
+ email:
+ - yoshoku@outlook.com
+ executables: []
+ extensions: []
+ extra_rdoc_files: []
+ files:
+ - LICENSE.txt
+ - README.md
+ - lib/rumale/model_selection.rb
+ - lib/rumale/model_selection/cross_validation.rb
+ - lib/rumale/model_selection/function.rb
+ - lib/rumale/model_selection/grid_search_cv.rb
+ - lib/rumale/model_selection/group_k_fold.rb
+ - lib/rumale/model_selection/group_shuffle_split.rb
+ - lib/rumale/model_selection/k_fold.rb
+ - lib/rumale/model_selection/shuffle_split.rb
+ - lib/rumale/model_selection/stratified_k_fold.rb
+ - lib/rumale/model_selection/stratified_shuffle_split.rb
+ - lib/rumale/model_selection/time_series_split.rb
+ - lib/rumale/model_selection/version.rb
+ homepage: https://github.com/yoshoku/rumale
+ licenses:
+ - BSD-3-Clause
+ metadata:
+   homepage_uri: https://github.com/yoshoku/rumale
+   source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-model_selection
+   changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
+   documentation_uri: https://yoshoku.github.io/rumale/doc/
+   rubygems_mfa_required: 'true'
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.3.26
+ signing_key:
+ specification_version: 4
+ summary: Rumale::ModelSelection provides model validation techniques with the Rumale interface.
+ test_files: []