rumale 0.20.0 → 0.22.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. checksums.yaml +4 -4
  2. data/.github/workflows/build.yml +23 -0
  3. data/.rubocop.yml +15 -95
  4. data/CHANGELOG.md +28 -0
  5. data/Gemfile +4 -2
  6. data/README.md +5 -2
  7. data/lib/rumale.rb +3 -0
  8. data/lib/rumale/clustering/hdbscan.rb +2 -2
  9. data/lib/rumale/clustering/snn.rb +1 -1
  10. data/lib/rumale/dataset.rb +1 -1
  11. data/lib/rumale/decomposition/nmf.rb +2 -2
  12. data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
  13. data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
  14. data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
  15. data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
  16. data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
  17. data/lib/rumale/linear_model/base_sgd.rb +1 -1
  18. data/lib/rumale/linear_model/elastic_net.rb +2 -2
  19. data/lib/rumale/linear_model/lasso.rb +2 -2
  20. data/lib/rumale/linear_model/linear_regression.rb +2 -2
  21. data/lib/rumale/linear_model/logistic_regression.rb +123 -35
  22. data/lib/rumale/linear_model/ridge.rb +2 -2
  23. data/lib/rumale/linear_model/svc.rb +2 -2
  24. data/lib/rumale/linear_model/svr.rb +2 -2
  25. data/lib/rumale/manifold/tsne.rb +1 -1
  26. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -45
  27. data/lib/rumale/model_selection/group_k_fold.rb +93 -0
  28. data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
  29. data/lib/rumale/model_selection/shuffle_split.rb +4 -4
  30. data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
  31. data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
  32. data/lib/rumale/model_selection/time_series_split.rb +91 -0
  33. data/lib/rumale/pipeline/pipeline.rb +1 -1
  34. data/lib/rumale/probabilistic_output.rb +1 -1
  35. data/lib/rumale/tree/base_decision_tree.rb +2 -9
  36. data/lib/rumale/tree/gradient_tree_regressor.rb +3 -10
  37. data/lib/rumale/version.rb +1 -1
  38. data/rumale.gemspec +1 -0
  39. metadata +21 -4
  40. data/.coveralls.yml +0 -1
data/lib/rumale/linear_model/ridge.rb
@@ -10,7 +10,7 @@ module Rumale
     #
     # @example
     #   estimator =
-    #     Rumale::LinearModel::Ridge.new(reg_param: 0.1, max_iter: 500, batch_size: 20, random_seed: 1)
+    #     Rumale::LinearModel::Ridge.new(reg_param: 0.1, max_iter: 1000, batch_size: 20, random_seed: 1)
     #   estimator.fit(training_samples, traininig_values)
     #   results = estimator.predict(testing_samples)
     #
@@ -70,7 +70,7 @@ module Rumale
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
                     reg_param: 1.0, fit_bias: true, bias_scale: 1.0,
-                    max_iter: 200, batch_size: 50, tol: 1e-4,
+                    max_iter: 1000, batch_size: 50, tol: 1e-4,
                     solver: 'auto',
                     n_jobs: nil, verbose: false, random_seed: nil)
        check_params_numeric(learning_rate: learning_rate, momentum: momentum,
data/lib/rumale/linear_model/svc.rb
@@ -17,7 +17,7 @@ module Rumale
     #
     # @example
     #   estimator =
-    #     Rumale::LinearModel::SVC.new(reg_param: 1.0, max_iter: 200, batch_size: 50, random_seed: 1)
+    #     Rumale::LinearModel::SVC.new(reg_param: 1.0, max_iter: 1000, batch_size: 50, random_seed: 1)
     #   estimator.fit(training_samples, traininig_labels)
     #   results = estimator.predict(testing_samples)
     #
@@ -74,7 +74,7 @@ module Rumale
      def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
                     penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
                     fit_bias: true, bias_scale: 1.0,
-                    max_iter: 200, batch_size: 50, tol: 1e-4,
+                    max_iter: 1000, batch_size: 50, tol: 1e-4,
                     probability: false,
                     n_jobs: nil, verbose: false, random_seed: nil)
        check_params_numeric(learning_rate: learning_rate, momentum: momentum,
data/lib/rumale/linear_model/svr.rb
@@ -14,7 +14,7 @@ module Rumale
     #
     # @example
     #   estimator =
-    #     Rumale::LinearModel::SVR.new(reg_param: 1.0, epsilon: 0.1, max_iter: 200, batch_size: 50, random_seed: 1)
+    #     Rumale::LinearModel::SVR.new(reg_param: 1.0, epsilon: 0.1, max_iter: 1000, batch_size: 50, random_seed: 1)
     #   estimator.fit(training_samples, traininig_target_values)
     #   results = estimator.predict(testing_samples)
     #
@@ -68,7 +68,7 @@ module Rumale
                     penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
                     fit_bias: true, bias_scale: 1.0,
                     epsilon: 0.1,
-                    max_iter: 200, batch_size: 50, tol: 1e-4,
+                    max_iter: 1000, batch_size: 50, tol: 1e-4,
                     n_jobs: nil, verbose: false, random_seed: nil)
        check_params_numeric(learning_rate: learning_rate, momentum: momentum,
                             reg_param: reg_param, bias_scale: bias_scale, epsilon: epsilon,
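
The Ridge, SVC, and SVR hunks above all make the same change: the default SGD iteration budget max_iter goes from 200 to 1000, and the +2 -2 counts on the other linear_model files in the list are consistent with the same two-line update. A minimal sketch, assuming a stock rumale 0.22 install: code that relied on the old implicit default can pin max_iter explicitly to keep the previous budget.

    require 'rumale'

    # New default: up to 1000 epochs, still subject to tol-based early stopping.
    estimator = Rumale::LinearModel::SVC.new(reg_param: 1.0, random_seed: 1)

    # Reproduce the pre-0.21 iteration budget explicitly.
    legacy = Rumale::LinearModel::SVC.new(reg_param: 1.0, max_iter: 200, random_seed: 1)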
data/lib/rumale/manifold/tsne.rb
@@ -102,7 +102,7 @@ module Rumale
          break if terminate?(hi_prob_mat, lo_prob_mat)
 
          a = hi_prob_mat * lo_prob_mat
-          b = lo_prob_mat * lo_prob_mat
+          b = lo_prob_mat**2
          y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
          lo_prob_mat = t_distributed_probability_matrix(y)
          @n_iter = t + 1
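
The tsne.rb change is purely cosmetic: in Numo::NArray both * and ** operate element-wise, so lo_prob_mat**2 computes the same matrix as lo_prob_mat * lo_prob_mat and simply states the intent more directly. A quick check:

    require 'numo/narray'

    m = Numo::DFloat[[1.0, 2.0], [3.0, 4.0]]
    p (m * m - m**2).abs.max # => 0.0, the two expressions agree element-wise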
data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb
@@ -2,13 +2,13 @@
 
 require 'rumale/base/base_estimator'
 require 'rumale/base/transformer'
+require 'lbfgsb'
 
 module Rumale
   module MetricLearning
    # NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
    #
    # @example
-    #   require 'mopti'
    #   require 'rumale'
    #
    #   transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
@@ -39,7 +39,9 @@ module Rumale
      # @param init [String] The initialization method for components ('random' or 'pca').
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
+      #   This value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
      # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
+      #   If true is given, 'iterate.dat' file is generated by lbfgsb.rb.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
        check_params_numeric_or_nil(n_components: n_components, random_seed: random_seed)
@@ -65,8 +67,6 @@ module Rumale
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [NeighbourhoodComponentAnalysis] The learned classifier itself.
      def fit(x, y)
-        raise 'NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded.' unless enable_mopti?
-
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
@@ -102,17 +102,9 @@ module Rumale
 
      private
 
-      def enable_mopti?
-        if defined?(Mopti).nil?
-          warn('NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded. You should intall and load mopti gem in advance.')
-          return false
-        end
-        true
-      end
-
      def init_components(x, n_features, n_components)
        if @params[:init] == 'pca'
-          pca = Rumale::Decomposition::PCA.new(n_components: n_components, solver: 'evd')
+          pca = Rumale::Decomposition::PCA.new(n_components: n_components)
          pca.fit(x).components.flatten.dup
        else
          Rumale::Utils.rand_normal([n_features, n_components], @rng.dup).flatten.dup
@@ -127,28 +119,18 @@ module Rumale
        res[:x] = comp_init
        res[:n_iter] = 0
        # perform optimization.
-        optimizer = Mopti::ScaledConjugateGradient.new(
-          fnc: method(:nca_loss), jcb: method(:nca_dloss),
-          x_init: comp_init, args: [x, y],
-          max_iter: @params[:max_iter], ftol: @params[:tol]
+        verbose = @params[:verbose] ? 1 : -1
+        res = Lbfgsb.minimize(
+          fnc: method(:nca_fnc), jcb: true, x_init: comp_init, args: [x, y],
+          maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
        )
-        fold = 0.0
-        dold = 0.0
-        optimizer.each do |prm|
-          res = prm
-          puts "[NeighbourhoodComponentAnalysis] The value of objective function after #{res[:n_iter]} epochs: #{x.shape[0] - res[:fnc]}" if @params[:verbose]
-          break if (fold - res[:fnc]).abs <= @params[:tol] && (dold - res[:jcb]).abs <= @params[:tol]
-
-          fold = res[:fnc]
-          dold = res[:jcb]
-        end
        # return the results.
        n_iter = res[:n_iter]
        comps = n_components == 1 ? res[:x].dup : res[:x].reshape(n_components, n_features)
        [comps, n_iter]
      end
 
-      def nca_loss(w, x, y)
+      def nca_fnc(w, x, y)
        # initialize some variables.
        n_samples, n_features = x.shape
        n_components = w.size / n_features
@@ -157,32 +139,18 @@ module Rumale
        z = x.dot(w.transpose)
        # calculate probability matrix.
        prob_mat = probability_matrix(z)
-        # calculate loss.
+        # calculate loss and gradient.
        # NOTE:
        #   NCA attempts to maximize its objective function.
        #   For the minization algorithm, the objective function value is subtracted from the maixmum value (n_samples).
        mask_mat = y.expand_dims(1).eq(y)
        masked_prob_mat = prob_mat * mask_mat
-        n_samples - masked_prob_mat.sum
-      end
-
-      def nca_dloss(w, x, y)
-        # initialize some variables.
-        n_features = x.shape[1]
-        n_components = w.size / n_features
-        # projection.
-        w = w.reshape(n_components, n_features)
-        z = x.dot(w.transpose)
-        # calculate probability matrix.
-        prob_mat = probability_matrix(z)
-        # calculate gradient.
-        mask_mat = y.expand_dims(1).eq(y)
-        masked_prob_mat = prob_mat * mask_mat
+        loss = n_samples - masked_prob_mat.sum
        weighted_prob_mat = masked_prob_mat - prob_mat * masked_prob_mat.sum(1).expand_dims(1)
        weighted_prob_mat += weighted_prob_mat.transpose
        weighted_prob_mat[weighted_prob_mat.diag_indices] = -weighted_prob_mat.sum(0)
-        gradient = 2 * z.transpose.dot(weighted_prob_mat).dot(x)
-        -gradient.flatten.dup
+        gradient = -2 * z.transpose.dot(weighted_prob_mat).dot(x)
+        [loss, gradient.flatten.dup]
      end
 
      def probability_matrix(z)
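
The optimizer swap above replaces Mopti's scaled conjugate gradient, with its hand-rolled convergence loop, by a single Lbfgsb.minimize call: with jcb: true, one function returns the [loss, gradient] pair, and tol maps onto L-BFGS-B's factr as tol / Lbfgsb::DBL_EPSILON (for the default tol of 1e-6 that is roughly 4.5e9, i.e. moderate accuracy). A minimal sketch of that calling convention, minimizing an illustrative quadratic f(w) = ||w - 3||^2 that is not part of the diff, and assuming a Proc is accepted wherever the diff passes a Method object:

    require 'lbfgsb'
    require 'numo/narray'

    target = Numo::DFloat[3.0, 3.0]
    fnc = proc do |w|
      d = w - target
      [(d**2).sum, 2 * d] # returns [loss, gradient], as nca_fnc does above
    end

    res = Lbfgsb.minimize(fnc: fnc, jcb: true, x_init: Numo::DFloat.zeros(2),
                          maxiter: 100, factr: 1e-6 / Lbfgsb::DBL_EPSILON)
    p res[:x] # ~ [3.0, 3.0]; the diff reads res[:x] and res[:n_iter] the same way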
data/lib/rumale/model_selection/group_k_fold.rb
@@ -0,0 +1,93 @@
+# frozen_string_literal: true
+
+require 'rumale/base/splitter'
+require 'rumale/preprocessing/label_encoder'
+
+module Rumale
+  module ModelSelection
+    # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
+    # The data points belonging to the same group do not be split into different folds.
+    # The number of groups should be greater than or equal to the number of splits.
+    #
+    # @example
+    #   cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
+    #   x = Numo::DFloat.new(8, 2).rand
+    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
+    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
+    #     puts '---'
+    #     pp train_ids
+    #     pp test_ids
+    #   end
+    #
+    #   # ---
+    #   # [0, 1, 2, 3, 4]
+    #   # [5, 6, 7]
+    #   # ---
+    #   # [3, 4, 5, 6, 7]
+    #   # [0, 1, 2]
+    #   # ---
+    #   # [0, 1, 2, 5, 6, 7]
+    #   # [3, 4]
+    #
+    class GroupKFold
+      include Base::Splitter
+
+      # Return the number of folds.
+      # @return [Integer]
+      attr_reader :n_splits
+
+      # Create a new data splitter for grouped K-fold cross validation.
+      #
+      # @param n_splits [Integer] The number of folds.
+      def initialize(n_splits: 5)
+        check_params_numeric(n_splits: n_splits)
+        @n_splits = n_splits
+      end
+
+      # Generate data indices for grouped K-fold cross validation.
+      #
+      # @overload split(x, y, groups) -> Array
+      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+      #     The dataset to be used to generate data indices for grouped K-fold cross validation.
+      #   @param y [Numo::Int32] (shape: [n_samples])
+      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
+      #   @param groups [Numo::Int32] (shape: [n_samples])
+      #     The group labels to be used to generate data indices for grouped K-fold cross validation.
+      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+      def split(x, _y, groups)
+        x = check_convert_sample_array(x)
+        groups = check_convert_label_array(groups)
+        check_sample_label_size(x, groups)
+
+        encoder = Rumale::Preprocessing::LabelEncoder.new
+        groups = encoder.fit_transform(groups)
+        n_groups = encoder.classes.size
+
+        raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits
+
+        n_samples_per_group = groups.bincount
+        group_ids = n_samples_per_group.sort_index.reverse
+        n_samples_per_group = n_samples_per_group[group_ids]
+
+        n_samples_per_fold = Numo::Int32.zeros(@n_splits)
+        group_to_fold = Numo::Int32.zeros(n_groups)
+
+        n_samples_per_group.each_with_index do |weight, id|
+          min_sample_fold_id = n_samples_per_fold.min_index
+          n_samples_per_fold[min_sample_fold_id] += weight
+          group_to_fold[group_ids[id]] = min_sample_fold_id
+        end
+
+        n_samples = x.shape[0]
+        sample_ids = Array(0...n_samples)
+        fold_ids = group_to_fold[groups]
+
+        Array.new(@n_splits) do |fid|
+          test_ids = fold_ids.eq(fid).where.to_a
+          train_ids = sample_ids - test_ids
+          [train_ids, test_ids]
+        end
+      end
+    end
+  end
+end
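
GroupKFold#split balances folds greedily: groups are sorted by size in descending order and each is assigned to whichever fold currently holds the fewest samples, so no group ever straddles two folds. A standalone plain-Ruby sketch of that assignment on the group sizes from the @example (no Numo, hypothetical variable names):

    sizes = { 1 => 3, 2 => 2, 3 => 3 } # group label => sample count
    n_splits = 3
    fold_sizes = Array.new(n_splits, 0)
    group_to_fold = {}

    sizes.sort_by { |_g, n| -n }.each do |g, n|
      fid = fold_sizes.index(fold_sizes.min) # fold with the fewest samples so far
      fold_sizes[fid] += n
      group_to_fold[g] = fid
    end
    p group_to_fold # e.g. {1=>0, 3=>1, 2=>2}: each group lands in exactly one fold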
data/lib/rumale/model_selection/group_shuffle_split.rb
@@ -0,0 +1,115 @@
+# frozen_string_literal: true
+
+require 'rumale/base/splitter'
+
+module Rumale
+  module ModelSelection
+    # GroupShuffleSplit is a class that generates the set of data indices
+    # for random permutation cross-validation by randomly selecting group labels.
+    #
+    # @example
+    #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
+    #   x = Numo::DFloat.new(8, 2).rand
+    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
+    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
+    #     puts '---'
+    #     pp train_ids
+    #     pp test_ids
+    #   end
+    #
+    #   # ---
+    #   # [0, 1, 2, 5, 6, 7]
+    #   # [3, 4]
+    #   # ---
+    #   # [3, 4, 5, 6, 7]
+    #   # [0, 1, 2]
+    #
+    class GroupShuffleSplit
+      include Base::Splitter
+
+      # Return the number of folds.
+      # @return [Integer]
+      attr_reader :n_splits
+
+      # Return the random generator for shuffling the dataset.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new data splitter for random permutation cross validation with given group labels.
+      #
+      # @param n_splits [Integer] The number of folds.
+      # @param test_size [Float] The ratio of number of groups for test data.
+      # @param train_size [Float/Nil] The ratio of number of groups for train data.
+      # @param random_seed [Integer] The seed value using to initialize the random generator.
+      def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
+        check_params_numeric(n_splits: n_splits, test_size: test_size)
+        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
+        check_params_positive(n_splits: n_splits)
+        check_params_positive(test_size: test_size)
+        check_params_positive(train_size: train_size) unless train_size.nil?
+        @n_splits = n_splits
+        @test_size = test_size
+        @train_size = train_size
+        @random_seed = random_seed
+        @random_seed ||= srand
+        @rng = Random.new(@random_seed)
+      end
+
+      # Generate train and test data indices by randomly selecting group labels.
+      #
+      # @overload split(x, y, groups) -> Array
+      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
+      #     The dataset to be used to generate data indices for random permutation cross validation.
+      #   @param y [Numo::Int32] (shape: [n_samples])
+      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
+      #   @param groups [Numo::Int32] (shape: [n_samples])
+      #     The group labels to be used to generate data indices for random permutation cross validation.
+      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+      def split(x, _y, groups)
+        x = check_convert_sample_array(x)
+        groups = check_convert_label_array(groups)
+        check_sample_label_size(x, groups)
+
+        classes = groups.to_a.uniq.sort
+        n_groups = classes.size
+        n_test_groups = (@test_size * n_groups).ceil.to_i
+        n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i
+
+        unless n_test_groups.between?(1, n_groups)
+          raise RangeError,
+                'The number of groups in test split must be not less than 1 and not more than the number of groups.'
+        end
+        unless n_train_groups.between?(1, n_groups)
+          raise RangeError,
+                'The number of groups in train split must be not less than 1 and not more than the number of groups.'
+        end
+        if (n_test_groups + n_train_groups) > n_groups
+          raise RangeError,
+                'The total number of groups in test split and train split must be not more than the number of groups.'
+        end
+
+        sub_rng = @rng.dup
+
+        Array.new(@n_splits) do
+          test_group_ids = classes.sample(n_test_groups, random: sub_rng)
+          train_group_ids = if @train_size.nil?
+                              classes - test_group_ids
+                            else
+                              (classes - test_group_ids).sample(n_train_groups, random: sub_rng)
+                            end
+          test_ids = in1d(groups, test_group_ids).where.to_a
+          train_ids = in1d(groups, train_group_ids).where.to_a
+          [train_ids, test_ids]
+        end
+      end
+
+      private
+
+      def in1d(a, b)
+        res = Numo::Bit.zeros(a.shape[0])
+        b.each { |v| res |= a.eq(v) }
+        res
+      end
+    end
+  end
+end
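
Unlike ShuffleSplit, the test_size and train_size ratios here apply to the number of groups, not the number of samples, and whole groups are moved between splits. Worked numbers for the @example above:

    n_groups = 3                               # groups 1, 2, 3
    n_test_groups = (0.2 * n_groups).ceil.to_i # => 1 whole group per test split
    n_train_groups = n_groups - n_test_groups  # => 2 (train_size is nil)
    # so each of the two splits holds out one randomly chosen group, e.g. group 2 => [3, 4]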
data/lib/rumale/model_selection/shuffle_split.rb
@@ -54,19 +54,19 @@ module Rumale
        x = check_convert_sample_array(x)
        # Initialize and check some variables.
        n_samples = x.shape[0]
-        n_test_samples = (@test_size * n_samples).to_i
-        n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
+        n_test_samples = (@test_size * n_samples).ceil.to_i
+        n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
        unless @n_splits.between?(1, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples.'
        end
        unless n_test_samples.between?(1, n_samples)
          raise RangeError,
-                'The number of sample in test split must be not less than 1 and not more than the number of samples.'
+                'The number of samples in test split must be not less than 1 and not more than the number of samples.'
        end
        unless n_train_samples.between?(1, n_samples)
          raise RangeError,
-                'The number of sample in train split must be not less than 1 and not more than the number of samples.'
+                'The number of samples in train split must be not less than 1 and not more than the number of samples.'
        end
        if (n_test_samples + n_train_samples) > n_samples
          raise RangeError,
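
The switch from to_i to ceil/floor fixes a small-dataset edge case: plain truncation could round a positive test_size down to zero samples and trip the RangeError below it. Ceiling the test count guarantees at least one sample for any positive ratio, while flooring the train count keeps the two splits from over-allocating. For example:

    n_samples = 3
    (0.2 * n_samples).to_i      # => 0, old behavior: an empty, invalid test split
    (0.2 * n_samples).ceil.to_i # => 1, new behavior: at least one test sample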
data/lib/rumale/model_selection/stratified_k_fold.rb
@@ -30,7 +30,7 @@ module Rumale
      # @return [Random]
      attr_reader :rng
 
-      # Create a new data splitter for K-fold cross validation.
+      # Create a new data splitter for stratified K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
data/lib/rumale/model_selection/stratified_shuffle_split.rb
@@ -66,15 +66,15 @@ module Rumale
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
        end
-        unless enough_data_size_each_class?(y, @test_size)
+        unless enough_data_size_each_class?(y, @test_size, 'test')
          raise RangeError,
-                'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
+                'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
        end
-        unless enough_data_size_each_class?(y, train_sz)
+        unless enough_data_size_each_class?(y, train_sz, 'train')
          raise RangeError,
-                'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
+                'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
        end
-        unless enough_data_size_each_class?(y, train_sz + @test_size)
+        unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
          raise RangeError,
                'The total number of samples in test split and train split must be not more than the number of samples in each class.'
        end
@@ -85,12 +85,12 @@ module Rumale
        test_ids = []
        sample_ids_each_class.each do |sample_ids|
          n_samples = sample_ids.size
-          n_test_samples = (@test_size * n_samples).to_i
-          n_train_samples = (train_sz * n_samples).to_i
+          n_test_samples = (@test_size * n_samples).ceil.to_i
          test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
          train_ids += if @train_size.nil?
                         sample_ids - test_ids
                       else
+                         n_train_samples = (train_sz * n_samples).floor.to_i
                         (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
                       end
        end
@@ -104,9 +104,13 @@ module Rumale
        y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
      end
 
-      def enough_data_size_each_class?(y, data_size)
+      def enough_data_size_each_class?(y, data_size, data_type)
        y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
-          (data_size * n_samples).to_i.between?(1, n_samples)
+          if data_type == 'test'
+            (data_size * n_samples).ceil.to_i.between?(1, n_samples)
+          else
+            (data_size * n_samples).floor.to_i.between?(1, n_samples)
+          end
        end
      end
    end