rumale 0.20.0 → 0.22.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.github/workflows/build.yml +23 -0
- data/.rubocop.yml +15 -95
- data/CHANGELOG.md +28 -0
- data/Gemfile +4 -2
- data/README.md +5 -2
- data/lib/rumale.rb +3 -0
- data/lib/rumale/clustering/hdbscan.rb +2 -2
- data/lib/rumale/clustering/snn.rb +1 -1
- data/lib/rumale/dataset.rb +1 -1
- data/lib/rumale/decomposition/nmf.rb +2 -2
- data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
- data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
- data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
- data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
- data/lib/rumale/linear_model/base_sgd.rb +1 -1
- data/lib/rumale/linear_model/elastic_net.rb +2 -2
- data/lib/rumale/linear_model/lasso.rb +2 -2
- data/lib/rumale/linear_model/linear_regression.rb +2 -2
- data/lib/rumale/linear_model/logistic_regression.rb +123 -35
- data/lib/rumale/linear_model/ridge.rb +2 -2
- data/lib/rumale/linear_model/svc.rb +2 -2
- data/lib/rumale/linear_model/svr.rb +2 -2
- data/lib/rumale/manifold/tsne.rb +1 -1
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -45
- data/lib/rumale/model_selection/group_k_fold.rb +93 -0
- data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
- data/lib/rumale/model_selection/shuffle_split.rb +4 -4
- data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
- data/lib/rumale/model_selection/time_series_split.rb +91 -0
- data/lib/rumale/pipeline/pipeline.rb +1 -1
- data/lib/rumale/probabilistic_output.rb +1 -1
- data/lib/rumale/tree/base_decision_tree.rb +2 -9
- data/lib/rumale/tree/gradient_tree_regressor.rb +3 -10
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +1 -0
- metadata +21 -4
- data/.coveralls.yml +0 -1
@@ -10,7 +10,7 @@ module Rumale
|
|
10
10
|
#
|
11
11
|
# @example
|
12
12
|
# estimator =
|
13
|
-
# Rumale::LinearModel::Ridge.new(reg_param: 0.1, max_iter:
|
13
|
+
# Rumale::LinearModel::Ridge.new(reg_param: 0.1, max_iter: 1000, batch_size: 20, random_seed: 1)
|
14
14
|
# estimator.fit(training_samples, traininig_values)
|
15
15
|
# results = estimator.predict(testing_samples)
|
16
16
|
#
|
@@ -70,7 +70,7 @@ module Rumale
|
|
70
70
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
71
71
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
72
72
|
reg_param: 1.0, fit_bias: true, bias_scale: 1.0,
|
73
|
-
max_iter:
|
73
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
74
74
|
solver: 'auto',
|
75
75
|
n_jobs: nil, verbose: false, random_seed: nil)
|
76
76
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
@@ -17,7 +17,7 @@ module Rumale
|
|
17
17
|
#
|
18
18
|
# @example
|
19
19
|
# estimator =
|
20
|
-
# Rumale::LinearModel::SVC.new(reg_param: 1.0, max_iter:
|
20
|
+
# Rumale::LinearModel::SVC.new(reg_param: 1.0, max_iter: 1000, batch_size: 50, random_seed: 1)
|
21
21
|
# estimator.fit(training_samples, traininig_labels)
|
22
22
|
# results = estimator.predict(testing_samples)
|
23
23
|
#
|
@@ -74,7 +74,7 @@ module Rumale
|
|
74
74
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
75
75
|
penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
|
76
76
|
fit_bias: true, bias_scale: 1.0,
|
77
|
-
max_iter:
|
77
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
78
78
|
probability: false,
|
79
79
|
n_jobs: nil, verbose: false, random_seed: nil)
|
80
80
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
@@ -14,7 +14,7 @@ module Rumale
|
|
14
14
|
#
|
15
15
|
# @example
|
16
16
|
# estimator =
|
17
|
-
# Rumale::LinearModel::SVR.new(reg_param: 1.0, epsilon: 0.1, max_iter:
|
17
|
+
# Rumale::LinearModel::SVR.new(reg_param: 1.0, epsilon: 0.1, max_iter: 1000, batch_size: 50, random_seed: 1)
|
18
18
|
# estimator.fit(training_samples, traininig_target_values)
|
19
19
|
# results = estimator.predict(testing_samples)
|
20
20
|
#
|
@@ -68,7 +68,7 @@ module Rumale
|
|
68
68
|
penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
|
69
69
|
fit_bias: true, bias_scale: 1.0,
|
70
70
|
epsilon: 0.1,
|
71
|
-
max_iter:
|
71
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
72
72
|
n_jobs: nil, verbose: false, random_seed: nil)
|
73
73
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
74
74
|
reg_param: reg_param, bias_scale: bias_scale, epsilon: epsilon,
|
data/lib/rumale/manifold/tsne.rb
CHANGED
@@ -102,7 +102,7 @@ module Rumale
|
|
102
102
|
break if terminate?(hi_prob_mat, lo_prob_mat)
|
103
103
|
|
104
104
|
a = hi_prob_mat * lo_prob_mat
|
105
|
-
b = lo_prob_mat
|
105
|
+
b = lo_prob_mat**2
|
106
106
|
y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
|
107
107
|
lo_prob_mat = t_distributed_probability_matrix(y)
|
108
108
|
@n_iter = t + 1
|
@@ -2,13 +2,13 @@
|
|
2
2
|
|
3
3
|
require 'rumale/base/base_estimator'
|
4
4
|
require 'rumale/base/transformer'
|
5
|
+
require 'lbfgsb'
|
5
6
|
|
6
7
|
module Rumale
|
7
8
|
module MetricLearning
|
8
9
|
# NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
|
9
10
|
#
|
10
11
|
# @example
|
11
|
-
# require 'mopti'
|
12
12
|
# require 'rumale'
|
13
13
|
#
|
14
14
|
# transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
|
@@ -39,7 +39,9 @@ module Rumale
|
|
39
39
|
# @param init [String] The initialization method for components ('random' or 'pca').
|
40
40
|
# @param max_iter [Integer] The maximum number of iterations.
|
41
41
|
# @param tol [Float] The tolerance of termination criterion.
|
42
|
+
# This value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
|
42
43
|
# @param verbose [Boolean] The flag indicating whether to output loss during iteration.
|
44
|
+
# If true is given, 'iterate.dat' file is generated by lbfgsb.rb.
|
43
45
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
44
46
|
def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
|
45
47
|
check_params_numeric_or_nil(n_components: n_components, random_seed: random_seed)
|
@@ -65,8 +67,6 @@ module Rumale
|
|
65
67
|
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
|
66
68
|
# @return [NeighbourhoodComponentAnalysis] The learned classifier itself.
|
67
69
|
def fit(x, y)
|
68
|
-
raise 'NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded.' unless enable_mopti?
|
69
|
-
|
70
70
|
x = check_convert_sample_array(x)
|
71
71
|
y = check_convert_label_array(y)
|
72
72
|
check_sample_label_size(x, y)
|
@@ -102,17 +102,9 @@ module Rumale
|
|
102
102
|
|
103
103
|
private
|
104
104
|
|
105
|
-
def enable_mopti?
|
106
|
-
if defined?(Mopti).nil?
|
107
|
-
warn('NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded. You should intall and load mopti gem in advance.')
|
108
|
-
return false
|
109
|
-
end
|
110
|
-
true
|
111
|
-
end
|
112
|
-
|
113
105
|
def init_components(x, n_features, n_components)
|
114
106
|
if @params[:init] == 'pca'
|
115
|
-
pca = Rumale::Decomposition::PCA.new(n_components: n_components
|
107
|
+
pca = Rumale::Decomposition::PCA.new(n_components: n_components)
|
116
108
|
pca.fit(x).components.flatten.dup
|
117
109
|
else
|
118
110
|
Rumale::Utils.rand_normal([n_features, n_components], @rng.dup).flatten.dup
|
@@ -127,28 +119,18 @@ module Rumale
|
|
127
119
|
res[:x] = comp_init
|
128
120
|
res[:n_iter] = 0
|
129
121
|
# perform optimization.
|
130
|
-
|
131
|
-
|
132
|
-
x_init: comp_init, args: [x, y],
|
133
|
-
|
122
|
+
verbose = @params[:verbose] ? 1 : -1
|
123
|
+
res = Lbfgsb.minimize(
|
124
|
+
fnc: method(:nca_fnc), jcb: true, x_init: comp_init, args: [x, y],
|
125
|
+
maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
|
134
126
|
)
|
135
|
-
fold = 0.0
|
136
|
-
dold = 0.0
|
137
|
-
optimizer.each do |prm|
|
138
|
-
res = prm
|
139
|
-
puts "[NeighbourhoodComponentAnalysis] The value of objective function after #{res[:n_iter]} epochs: #{x.shape[0] - res[:fnc]}" if @params[:verbose]
|
140
|
-
break if (fold - res[:fnc]).abs <= @params[:tol] && (dold - res[:jcb]).abs <= @params[:tol]
|
141
|
-
|
142
|
-
fold = res[:fnc]
|
143
|
-
dold = res[:jcb]
|
144
|
-
end
|
145
127
|
# return the results.
|
146
128
|
n_iter = res[:n_iter]
|
147
129
|
comps = n_components == 1 ? res[:x].dup : res[:x].reshape(n_components, n_features)
|
148
130
|
[comps, n_iter]
|
149
131
|
end
|
150
132
|
|
151
|
-
def
|
133
|
+
def nca_fnc(w, x, y)
|
152
134
|
# initialize some variables.
|
153
135
|
n_samples, n_features = x.shape
|
154
136
|
n_components = w.size / n_features
|
@@ -157,32 +139,18 @@ module Rumale
|
|
157
139
|
z = x.dot(w.transpose)
|
158
140
|
# calculate probability matrix.
|
159
141
|
prob_mat = probability_matrix(z)
|
160
|
-
# calculate loss.
|
142
|
+
# calculate loss and gradient.
|
161
143
|
# NOTE:
|
162
144
|
# NCA attempts to maximize its objective function.
|
163
145
|
# For the minization algorithm, the objective function value is subtracted from the maixmum value (n_samples).
|
164
146
|
mask_mat = y.expand_dims(1).eq(y)
|
165
147
|
masked_prob_mat = prob_mat * mask_mat
|
166
|
-
n_samples - masked_prob_mat.sum
|
167
|
-
end
|
168
|
-
|
169
|
-
def nca_dloss(w, x, y)
|
170
|
-
# initialize some variables.
|
171
|
-
n_features = x.shape[1]
|
172
|
-
n_components = w.size / n_features
|
173
|
-
# projection.
|
174
|
-
w = w.reshape(n_components, n_features)
|
175
|
-
z = x.dot(w.transpose)
|
176
|
-
# calculate probability matrix.
|
177
|
-
prob_mat = probability_matrix(z)
|
178
|
-
# calculate gradient.
|
179
|
-
mask_mat = y.expand_dims(1).eq(y)
|
180
|
-
masked_prob_mat = prob_mat * mask_mat
|
148
|
+
loss = n_samples - masked_prob_mat.sum
|
181
149
|
weighted_prob_mat = masked_prob_mat - prob_mat * masked_prob_mat.sum(1).expand_dims(1)
|
182
150
|
weighted_prob_mat += weighted_prob_mat.transpose
|
183
151
|
weighted_prob_mat[weighted_prob_mat.diag_indices] = -weighted_prob_mat.sum(0)
|
184
|
-
gradient = 2 * z.transpose.dot(weighted_prob_mat).dot(x)
|
185
|
-
|
152
|
+
gradient = -2 * z.transpose.dot(weighted_prob_mat).dot(x)
|
153
|
+
[loss, gradient.flatten.dup]
|
186
154
|
end
|
187
155
|
|
188
156
|
def probability_matrix(z)
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
require 'rumale/preprocessing/label_encoder'
|
5
|
+
|
6
|
+
module Rumale
|
7
|
+
module ModelSelection
|
8
|
+
# GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
|
9
|
+
# The data points belonging to the same group do not be split into different folds.
|
10
|
+
# The number of groups should be greater than or equal to the number of splits.
|
11
|
+
#
|
12
|
+
# @example
|
13
|
+
# cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
|
14
|
+
# x = Numo::DFloat.new(8, 2).rand
|
15
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
16
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
17
|
+
# puts '---'
|
18
|
+
# pp train_ids
|
19
|
+
# pp test_ids
|
20
|
+
# end
|
21
|
+
#
|
22
|
+
# # ---
|
23
|
+
# # [0, 1, 2, 3, 4]
|
24
|
+
# # [5, 6, 7]
|
25
|
+
# # ---
|
26
|
+
# # [3, 4, 5, 6, 7]
|
27
|
+
# # [0, 1, 2]
|
28
|
+
# # ---
|
29
|
+
# # [0, 1, 2, 5, 6, 7]
|
30
|
+
# # [3, 4]
|
31
|
+
#
|
32
|
+
class GroupKFold
|
33
|
+
include Base::Splitter
|
34
|
+
|
35
|
+
# Return the number of folds.
|
36
|
+
# @return [Integer]
|
37
|
+
attr_reader :n_splits
|
38
|
+
|
39
|
+
# Create a new data splitter for grouped K-fold cross validation.
|
40
|
+
#
|
41
|
+
# @param n_splits [Integer] The number of folds.
|
42
|
+
def initialize(n_splits: 5)
|
43
|
+
check_params_numeric(n_splits: n_splits)
|
44
|
+
@n_splits = n_splits
|
45
|
+
end
|
46
|
+
|
47
|
+
# Generate data indices for grouped K-fold cross validation.
|
48
|
+
#
|
49
|
+
# @overload split(x, y, groups) -> Array
|
50
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
51
|
+
# The dataset to be used to generate data indices for grouped K-fold cross validation.
|
52
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
53
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
54
|
+
# @param groups [Numo::Int32] (shape: [n_samples])
|
55
|
+
# The group labels to be used to generate data indices for grouped K-fold cross validation.
|
56
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
57
|
+
def split(x, _y, groups)
|
58
|
+
x = check_convert_sample_array(x)
|
59
|
+
groups = check_convert_label_array(groups)
|
60
|
+
check_sample_label_size(x, groups)
|
61
|
+
|
62
|
+
encoder = Rumale::Preprocessing::LabelEncoder.new
|
63
|
+
groups = encoder.fit_transform(groups)
|
64
|
+
n_groups = encoder.classes.size
|
65
|
+
|
66
|
+
raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits
|
67
|
+
|
68
|
+
n_samples_per_group = groups.bincount
|
69
|
+
group_ids = n_samples_per_group.sort_index.reverse
|
70
|
+
n_samples_per_group = n_samples_per_group[group_ids]
|
71
|
+
|
72
|
+
n_samples_per_fold = Numo::Int32.zeros(@n_splits)
|
73
|
+
group_to_fold = Numo::Int32.zeros(n_groups)
|
74
|
+
|
75
|
+
n_samples_per_group.each_with_index do |weight, id|
|
76
|
+
min_sample_fold_id = n_samples_per_fold.min_index
|
77
|
+
n_samples_per_fold[min_sample_fold_id] += weight
|
78
|
+
group_to_fold[group_ids[id]] = min_sample_fold_id
|
79
|
+
end
|
80
|
+
|
81
|
+
n_samples = x.shape[0]
|
82
|
+
sample_ids = Array(0...n_samples)
|
83
|
+
fold_ids = group_to_fold[groups]
|
84
|
+
|
85
|
+
Array.new(@n_splits) do |fid|
|
86
|
+
test_ids = fold_ids.eq(fid).where.to_a
|
87
|
+
train_ids = sample_ids - test_ids
|
88
|
+
[train_ids, test_ids]
|
89
|
+
end
|
90
|
+
end
|
91
|
+
end
|
92
|
+
end
|
93
|
+
end
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# GroupShuffleSplit is a class that generates the set of data indices
|
8
|
+
# for random permutation cross-validation by randomly selecting group labels.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
|
12
|
+
# x = Numo::DFloat.new(8, 2).rand
|
13
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
14
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
15
|
+
# puts '---'
|
16
|
+
# pp train_ids
|
17
|
+
# pp test_ids
|
18
|
+
# end
|
19
|
+
#
|
20
|
+
# # ---
|
21
|
+
# # [0, 1, 2, 5, 6, 7]
|
22
|
+
# # [3, 4]
|
23
|
+
# # ---
|
24
|
+
# # [3, 4, 5, 6, 7]
|
25
|
+
# # [0, 1, 2]
|
26
|
+
#
|
27
|
+
class GroupShuffleSplit
|
28
|
+
include Base::Splitter
|
29
|
+
|
30
|
+
# Return the number of folds.
|
31
|
+
# @return [Integer]
|
32
|
+
attr_reader :n_splits
|
33
|
+
|
34
|
+
# Return the random generator for shuffling the dataset.
|
35
|
+
# @return [Random]
|
36
|
+
attr_reader :rng
|
37
|
+
|
38
|
+
# Create a new data splitter for random permutation cross validation with given group labels.
|
39
|
+
#
|
40
|
+
# @param n_splits [Integer] The number of folds.
|
41
|
+
# @param test_size [Float] The ratio of number of groups for test data.
|
42
|
+
# @param train_size [Float/Nil] The ratio of number of groups for train data.
|
43
|
+
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
44
|
+
def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
|
45
|
+
check_params_numeric(n_splits: n_splits, test_size: test_size)
|
46
|
+
check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
|
47
|
+
check_params_positive(n_splits: n_splits)
|
48
|
+
check_params_positive(test_size: test_size)
|
49
|
+
check_params_positive(train_size: train_size) unless train_size.nil?
|
50
|
+
@n_splits = n_splits
|
51
|
+
@test_size = test_size
|
52
|
+
@train_size = train_size
|
53
|
+
@random_seed = random_seed
|
54
|
+
@random_seed ||= srand
|
55
|
+
@rng = Random.new(@random_seed)
|
56
|
+
end
|
57
|
+
|
58
|
+
# Generate train and test data indices by randomly selecting group labels.
|
59
|
+
#
|
60
|
+
# @overload split(x, y, groups) -> Array
|
61
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
62
|
+
# The dataset to be used to generate data indices for random permutation cross validation.
|
63
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
64
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
65
|
+
# @param groups [Numo::Int32] (shape: [n_samples])
|
66
|
+
# The group labels to be used to generate data indices for random permutation cross validation.
|
67
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
68
|
+
def split(x, _y, groups)
|
69
|
+
x = check_convert_sample_array(x)
|
70
|
+
groups = check_convert_label_array(groups)
|
71
|
+
check_sample_label_size(x, groups)
|
72
|
+
|
73
|
+
classes = groups.to_a.uniq.sort
|
74
|
+
n_groups = classes.size
|
75
|
+
n_test_groups = (@test_size * n_groups).ceil.to_i
|
76
|
+
n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i
|
77
|
+
|
78
|
+
unless n_test_groups.between?(1, n_groups)
|
79
|
+
raise RangeError,
|
80
|
+
'The number of groups in test split must be not less than 1 and not more than the number of groups.'
|
81
|
+
end
|
82
|
+
unless n_train_groups.between?(1, n_groups)
|
83
|
+
raise RangeError,
|
84
|
+
'The number of groups in train split must be not less than 1 and not more than the number of groups.'
|
85
|
+
end
|
86
|
+
if (n_test_groups + n_train_groups) > n_groups
|
87
|
+
raise RangeError,
|
88
|
+
'The total number of groups in test split and train split must be not more than the number of groups.'
|
89
|
+
end
|
90
|
+
|
91
|
+
sub_rng = @rng.dup
|
92
|
+
|
93
|
+
Array.new(@n_splits) do
|
94
|
+
test_group_ids = classes.sample(n_test_groups, random: sub_rng)
|
95
|
+
train_group_ids = if @train_size.nil?
|
96
|
+
classes - test_group_ids
|
97
|
+
else
|
98
|
+
(classes - test_group_ids).sample(n_train_groups, random: sub_rng)
|
99
|
+
end
|
100
|
+
test_ids = in1d(groups, test_group_ids).where.to_a
|
101
|
+
train_ids = in1d(groups, train_group_ids).where.to_a
|
102
|
+
[train_ids, test_ids]
|
103
|
+
end
|
104
|
+
end
|
105
|
+
|
106
|
+
private
|
107
|
+
|
108
|
+
def in1d(a, b)
|
109
|
+
res = Numo::Bit.zeros(a.shape[0])
|
110
|
+
b.each { |v| res |= a.eq(v) }
|
111
|
+
res
|
112
|
+
end
|
113
|
+
end
|
114
|
+
end
|
115
|
+
end
|
@@ -54,19 +54,19 @@ module Rumale
|
|
54
54
|
x = check_convert_sample_array(x)
|
55
55
|
# Initialize and check some variables.
|
56
56
|
n_samples = x.shape[0]
|
57
|
-
n_test_samples = (@test_size * n_samples).to_i
|
58
|
-
n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
|
57
|
+
n_test_samples = (@test_size * n_samples).ceil.to_i
|
58
|
+
n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
|
59
59
|
unless @n_splits.between?(1, n_samples)
|
60
60
|
raise ArgumentError,
|
61
61
|
'The value of n_splits must be not less than 1 and not more than the number of samples.'
|
62
62
|
end
|
63
63
|
unless n_test_samples.between?(1, n_samples)
|
64
64
|
raise RangeError,
|
65
|
-
'The number of
|
65
|
+
'The number of samples in test split must be not less than 1 and not more than the number of samples.'
|
66
66
|
end
|
67
67
|
unless n_train_samples.between?(1, n_samples)
|
68
68
|
raise RangeError,
|
69
|
-
'The number of
|
69
|
+
'The number of samples in train split must be not less than 1 and not more than the number of samples.'
|
70
70
|
end
|
71
71
|
if (n_test_samples + n_train_samples) > n_samples
|
72
72
|
raise RangeError,
|
@@ -30,7 +30,7 @@ module Rumale
|
|
30
30
|
# @return [Random]
|
31
31
|
attr_reader :rng
|
32
32
|
|
33
|
-
# Create a new data splitter for K-fold cross validation.
|
33
|
+
# Create a new data splitter for stratified K-fold cross validation.
|
34
34
|
#
|
35
35
|
# @param n_splits [Integer] The number of folds.
|
36
36
|
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
|
@@ -66,15 +66,15 @@ module Rumale
|
|
66
66
|
raise ArgumentError,
|
67
67
|
'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
|
68
68
|
end
|
69
|
-
unless enough_data_size_each_class?(y, @test_size)
|
69
|
+
unless enough_data_size_each_class?(y, @test_size, 'test')
|
70
70
|
raise RangeError,
|
71
|
-
'The number of
|
71
|
+
'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
|
72
72
|
end
|
73
|
-
unless enough_data_size_each_class?(y, train_sz)
|
73
|
+
unless enough_data_size_each_class?(y, train_sz, 'train')
|
74
74
|
raise RangeError,
|
75
|
-
'The number of
|
75
|
+
'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
|
76
76
|
end
|
77
|
-
unless enough_data_size_each_class?(y, train_sz + @test_size)
|
77
|
+
unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
|
78
78
|
raise RangeError,
|
79
79
|
'The total number of samples in test split and train split must be not more than the number of samples in each class.'
|
80
80
|
end
|
@@ -85,12 +85,12 @@ module Rumale
|
|
85
85
|
test_ids = []
|
86
86
|
sample_ids_each_class.each do |sample_ids|
|
87
87
|
n_samples = sample_ids.size
|
88
|
-
n_test_samples = (@test_size * n_samples).to_i
|
89
|
-
n_train_samples = (train_sz * n_samples).to_i
|
88
|
+
n_test_samples = (@test_size * n_samples).ceil.to_i
|
90
89
|
test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
|
91
90
|
train_ids += if @train_size.nil?
|
92
91
|
sample_ids - test_ids
|
93
92
|
else
|
93
|
+
n_train_samples = (train_sz * n_samples).floor.to_i
|
94
94
|
(sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
|
95
95
|
end
|
96
96
|
end
|
@@ -104,9 +104,13 @@ module Rumale
|
|
104
104
|
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
|
105
105
|
end
|
106
106
|
|
107
|
-
def enough_data_size_each_class?(y, data_size)
|
107
|
+
def enough_data_size_each_class?(y, data_size, data_type)
|
108
108
|
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
|
109
|
-
|
109
|
+
if data_type == 'test'
|
110
|
+
(data_size * n_samples).ceil.to_i.between?(1, n_samples)
|
111
|
+
else
|
112
|
+
(data_size * n_samples).floor.to_i.between?(1, n_samples)
|
113
|
+
end
|
110
114
|
end
|
111
115
|
end
|
112
116
|
end
|