rumale 0.19.0 → 0.20.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +5 -29
  3. data/CHANGELOG.md +28 -0
  4. data/lib/rumale.rb +7 -10
  5. data/lib/rumale/clustering/hdbscan.rb +3 -3
  6. data/lib/rumale/clustering/k_means.rb +1 -1
  7. data/lib/rumale/clustering/k_medoids.rb +1 -1
  8. data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
  9. data/lib/rumale/dataset.rb +4 -4
  10. data/lib/rumale/decomposition/nmf.rb +2 -2
  11. data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
  12. data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
  13. data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
  14. data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
  15. data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
  16. data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
  17. data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
  18. data/lib/rumale/linear_model/base_sgd.rb +1 -1
  19. data/lib/rumale/manifold/tsne.rb +1 -1
  20. data/lib/rumale/model_selection/cross_validation.rb +3 -2
  21. data/lib/rumale/model_selection/group_k_fold.rb +93 -0
  22. data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
  23. data/lib/rumale/model_selection/k_fold.rb +1 -1
  24. data/lib/rumale/model_selection/shuffle_split.rb +5 -5
  25. data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
  26. data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
  27. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +2 -2
  28. data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
  29. data/lib/rumale/neural_network/adam.rb +1 -1
  30. data/lib/rumale/neural_network/base_mlp.rb +1 -1
  31. data/lib/rumale/preprocessing/binarizer.rb +60 -0
  32. data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
  33. data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
  34. data/lib/rumale/preprocessing/max_normalizer.rb +62 -0
  35. data/lib/rumale/probabilistic_output.rb +1 -1
  36. data/lib/rumale/version.rb +1 -1
  37. metadata +12 -15
  38. data/lib/rumale/linear_model/base_linear_model.rb +0 -102
  39. data/lib/rumale/optimizer/ada_grad.rb +0 -42
  40. data/lib/rumale/optimizer/adam.rb +0 -56
  41. data/lib/rumale/optimizer/nadam.rb +0 -67
  42. data/lib/rumale/optimizer/rmsprop.rb +0 -50
  43. data/lib/rumale/optimizer/sgd.rb +0 -46
  44. data/lib/rumale/optimizer/yellow_fin.rb +0 -104
  45. data/lib/rumale/polynomial_model/base_factorization_machine.rb +0 -125
  46. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +0 -220
  47. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +0 -134
@@ -0,0 +1,113 @@
# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'
require 'rumale/preprocessing/l1_normalizer'
require 'rumale/preprocessing/l2_normalizer'

module Rumale
  module FeatureExtraction
    # Transform a sample matrix of term frequencies (tf) into a normalized
    # tf-idf (term frequency-inverse document frequency) representation.
    #
    # @example
    #   encoder = Rumale::FeatureExtraction::HashVectorizer.new
    #   x = encoder.fit_transform([
    #     { foo: 1, bar: 2 },
    #     { foo: 3, baz: 1 }
    #   ])
    #
    #   # > pp x
    #   # Numo::DFloat#shape=[2,3]
    #   # [[2, 0, 1],
    #   #  [0, 1, 3]]
    #
    #   transformer = Rumale::FeatureExtraction::TfidfTransformer.new
    #   x_tfidf = transformer.fit_transform(x)
    #
    #   # > pp x_tfidf
    #   # Numo::DFloat#shape=[2,3]
    #   # [[0.959056, 0, 0.283217],
    #   #  [0, 0.491506, 0.870874]]
    #
    # *Reference*
    # - Manning, C D., Raghavan, P., and Schutze, H., "Introduction to Information Retrieval," Cambridge University Press., 2008.
    class TfidfTransformer
      include Base::BaseEstimator
      include Base::Transformer

      # Return the vector consisting of inverse document frequencies.
      # It stays nil until {#fit} is called with use_idf enabled.
      # When smooth_idf is false, a feature that never occurs in the fitted samples has an infinite idf.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :idf

      # Create a new transformer for converting tf vectors to tf-idf vectors.
      #
      # @param norm [String] The normalization method to be used ('l1', 'l2' and 'none').
      # @param use_idf [Boolean] The flag indicating whether to use inverse document frequency weighting.
      # @param smooth_idf [Boolean] The flag indicating whether to apply idf smoothing by log((n_samples + 1) / (df + 1)) + 1.
      # @param sublinear_tf [Boolean] The flag indicating whether to perform sublinear tf scaling by 1 + log(tf).
      def initialize(norm: 'l2', use_idf: true, smooth_idf: false, sublinear_tf: false)
        check_params_string(norm: norm)
        check_params_boolean(use_idf: use_idf, smooth_idf: smooth_idf, sublinear_tf: sublinear_tf)
        @params = {}
        @params[:norm] = norm
        @params[:use_idf] = use_idf
        @params[:smooth_idf] = smooth_idf
        @params[:sublinear_tf] = sublinear_tf
        @idf = nil
      end

      # Calculate the inverse document frequency for weighting.
      # Does nothing (and leaves idf as nil) when use_idf is false.
      #
      # @overload fit(x) -> TfidfTransformer
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the idf values.
      # @return [TfidfTransformer]
      def fit(x, _y = nil)
        return self unless @params[:use_idf]

        x = check_convert_sample_array(x)

        n_samples = x.shape[0]
        # Document frequency: the number of samples in which each feature is positive.
        df = x.class.cast(x.gt(0.0).count(0))

        if @params[:smooth_idf]
          df += 1
          n_samples += 1
        end

        @idf = Numo::NMath.log(n_samples / df) + 1

        self
      end

      # Calculate the idf values, and then transform samples to the tf-idf representation.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate idf and be transformed to tf-idf representation.
      # @return [Numo::DFloat] The transformed samples.
      def fit_transform(x, _y = nil)
        fit(x).transform(x)
      end

      # Perform transforming the given samples to the tf-idf representation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed.
      # @return [Numo::DFloat] The transformed samples.
      def transform(x)
        x = check_convert_sample_array(x)
        z = x.dup

        z[z.ne(0)] = Numo::NMath.log(z[z.ne(0)]) + 1 if @params[:sublinear_tf]
        z = scale_by_idf(z) if @params[:use_idf]
        case @params[:norm]
        when 'l2'
          z = Rumale::Preprocessing::L2Normalizer.new.fit_transform(z)
        when 'l1'
          z = Rumale::Preprocessing::L1Normalizer.new.fit_transform(z)
        end
        z
      end

      private

      # Multiply each column by its idf weight while keeping zero entries at exactly zero.
      # A column with infinite idf (feature absent from the fitted samples, smooth_idf false)
      # would otherwise turn its zero entries into 0 * Infinity = NaN, and the subsequent
      # row normalization would make every affected row entirely NaN.
      def scale_by_idf(z)
        zero_mask = z.eq(0)
        z *= @idf
        z[zero_mask] = 0.0 if zero_mask.any?
        z
      end
    end
  end
end
@@ -69,7 +69,7 @@ module Rumale
69
69
  n_components = [1, [@params[:n_components], n_samples].min].max
70
70
 
71
71
  # random sampling.
72
- @component_indices = Numo::Int32.cast([*0...n_samples].shuffle(random: sub_rng)[0...n_components])
72
+ @component_indices = Numo::Int32.cast(Array(0...n_samples).shuffle(random: sub_rng)[0...n_components])
73
73
  @components = x[@component_indices, true]
74
74
 
75
75
  # calculate normalizing factor.
@@ -172,7 +172,7 @@ module Rumale
172
172
  # Start optimization.
173
173
  @params[:max_iter].times do |t|
174
174
  # random sampling
175
- rand_ids = [*0...n_training_samples].shuffle(random: sub_rng) if rand_ids.empty?
175
+ rand_ids = Array(0...n_training_samples).shuffle(random: sub_rng) if rand_ids.empty?
176
176
  target_id = rand_ids.shift
177
177
  # update the weight vector
178
178
  func = (weight_vec * bin_y).dot(x[target_id, true].transpose).to_f
@@ -209,7 +209,7 @@ module Rumale
209
209
  l1_penalty = LinearModel::Penalty::L1Penalty.new(reg_param: l1_reg_param) if apply_l1_penalty?
210
210
  # Optimization.
211
211
  @params[:max_iter].times do |t|
212
- sample_ids = [*0...n_samples]
212
+ sample_ids = Array(0...n_samples)
213
213
  sample_ids.shuffle!(random: sub_rng)
214
214
  until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
215
215
  # sampling
@@ -102,7 +102,7 @@ module Rumale
102
102
  break if terminate?(hi_prob_mat, lo_prob_mat)
103
103
 
104
104
  a = hi_prob_mat * lo_prob_mat
105
- b = lo_prob_mat * lo_prob_mat
105
+ b = lo_prob_mat**2
106
106
  y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
107
107
  lo_prob_mat = t_distributed_probability_matrix(y)
108
108
  @n_iter = t + 1
@@ -69,10 +69,11 @@ module Rumale
69
69
  # the return_train_score is false.
70
70
  def perform(x, y)
71
71
  x = check_convert_sample_array(x)
72
- if @estimator.is_a?(Rumale::Base::Classifier)
72
+ case @estimator
73
+ when Rumale::Base::Classifier
73
74
  y = check_convert_label_array(y)
74
75
  check_sample_label_size(x, y)
75
- elsif @estimator.is_a?(Rumale::Base::Regressor)
76
+ when Rumale::Base::Regressor
76
77
  y = check_convert_tvalue_array(y)
77
78
  check_sample_tvalue_size(x, y)
78
79
  else
@@ -0,0 +1,93 @@
# frozen_string_literal: true

require 'rumale/base/splitter'
require 'rumale/preprocessing/label_encoder'

module Rumale
  module ModelSelection
    # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
    # Data points belonging to the same group are never split into different folds.
    # The number of groups should be greater than or equal to the number of splits.
    #
    # @example
    #   cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 3, 4]
    #   # [5, 6, 7]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #
    class GroupKFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Create a new data splitter for grouped K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds. At least 2 folds are required when splitting.
      def initialize(n_splits: 5)
        check_params_numeric(n_splits: n_splits)
        @n_splits = n_splits
      end

      # Generate data indices for grouped K-fold cross validation.
      #
      # Groups are assigned to folds greedily: starting from the largest group, each group
      # goes to the fold that currently holds the fewest samples.
      #
      # @overload split(x, y, groups) -> Array
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for grouped K-fold cross validation.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      # @param groups [Numo::Int32] (shape: [n_samples])
      #   The group labels to be used to generate data indices for grouped K-fold cross validation.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      # @raise [ArgumentError] If n_splits is less than 2 or greater than the number of groups.
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        # With fewer than two folds the training set of a fold would be empty
        # (and zero folds would break the fold bookkeeping below).
        raise ArgumentError, 'The value of n_splits must be not less than 2.' if @n_splits < 2

        # Map arbitrary group labels to consecutive indices 0...n_groups.
        encoder = Rumale::Preprocessing::LabelEncoder.new
        groups = encoder.fit_transform(groups)
        n_groups = encoder.classes.size

        raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits

        # Visit groups in descending order of size.
        n_samples_per_group = groups.bincount
        group_ids = n_samples_per_group.sort_index.reverse
        n_samples_per_group = n_samples_per_group[group_ids]

        n_samples_per_fold = Numo::Int32.zeros(@n_splits)
        group_to_fold = Numo::Int32.zeros(n_groups)

        n_samples_per_group.each_with_index do |weight, id|
          min_sample_fold_id = n_samples_per_fold.min_index
          n_samples_per_fold[min_sample_fold_id] += weight
          group_to_fold[group_ids[id]] = min_sample_fold_id
        end

        n_samples = x.shape[0]
        sample_ids = Array(0...n_samples)
        fold_ids = group_to_fold[groups]

        Array.new(@n_splits) do |fid|
          test_ids = fold_ids.eq(fid).where.to_a
          train_ids = sample_ids - test_ids
          [train_ids, test_ids]
        end
      end
    end
  end
end
@@ -0,0 +1,115 @@
# frozen_string_literal: true

require 'rumale/base/splitter'

module Rumale
  module ModelSelection
    # GroupShuffleSplit generates train/test index sets for random permutation
    # cross-validation in which whole groups, rather than individual samples,
    # are randomly drawn into the test and train splits.
    #
    # @example
    #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #
    class GroupShuffleSplit
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation with given group labels.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of number of groups for test data.
      # @param train_size [Float/Nil] The ratio of number of groups for train data.
      #   When nil, every group not drawn for the test split is used for training.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
        check_params_numeric(n_splits: n_splits, test_size: test_size)
        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        check_params_positive(test_size: test_size)
        check_params_positive(train_size: train_size) unless train_size.nil?
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate train and test data indices by randomly selecting group labels.
      #
      # The number of test groups is rounded up and the number of train groups is rounded down.
      #
      # @overload split(x, y, groups) -> Array
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for random permutation cross validation.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      # @param groups [Numo::Int32] (shape: [n_samples])
      #   The group labels to be used to generate data indices for random permutation cross validation.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        group_labels = groups.to_a.uniq.sort
        n_groups = group_labels.size
        n_test_groups = (@test_size * n_groups).ceil.to_i
        n_train_groups = if @train_size.nil?
                           n_groups - n_test_groups
                         else
                           (@train_size * n_groups).floor.to_i
                         end
        ensure_valid_group_counts(n_test_groups, n_train_groups, n_groups)

        # Work on a copy so repeated calls to split yield the same folds.
        sub_rng = @rng.dup

        Array.new(@n_splits) do
          test_labels = group_labels.sample(n_test_groups, random: sub_rng)
          remaining_labels = group_labels - test_labels
          train_labels = @train_size.nil? ? remaining_labels : remaining_labels.sample(n_train_groups, random: sub_rng)
          [member_indices(groups, train_labels), member_indices(groups, test_labels)]
        end
      end

      private

      # Raise RangeError unless both split sizes fit within the available groups.
      def ensure_valid_group_counts(n_test_groups, n_train_groups, n_groups)
        unless n_test_groups.between?(1, n_groups)
          raise RangeError,
                'The number of groups in test split must be not less than 1 and not more than the number of groups.'
        end
        unless n_train_groups.between?(1, n_groups)
          raise RangeError,
                'The number of groups in train split must be not less than 1 and not more than the number of groups.'
        end
        return if n_test_groups + n_train_groups <= n_groups

        raise RangeError,
              'The total number of groups in test split and train split must be not more than the number of groups.'
      end

      # Return the ascending sample indices whose group label is one of the selected labels.
      def member_indices(labels, selected_labels)
        mask = Numo::Bit.zeros(labels.shape[0])
        selected_labels.each { |label| mask |= labels.eq(label) }
        mask.where.to_a
      end
    end
  end
end
@@ -62,7 +62,7 @@ module Rumale
62
62
  end
63
63
  sub_rng = @rng.dup
64
64
  # Splits dataset ids to each fold.
65
- dataset_ids = [*0...n_samples]
65
+ dataset_ids = Array(0...n_samples)
66
66
  dataset_ids.shuffle!(random: sub_rng) if @shuffle
67
67
  fold_sets = Array.new(@n_splits) do |n|
68
68
  n_fold_samples = n_samples / @n_splits
@@ -54,19 +54,19 @@ module Rumale
54
54
  x = check_convert_sample_array(x)
55
55
  # Initialize and check some variables.
56
56
  n_samples = x.shape[0]
57
- n_test_samples = (@test_size * n_samples).to_i
58
- n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
57
+ n_test_samples = (@test_size * n_samples).ceil.to_i
58
+ n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
59
59
  unless @n_splits.between?(1, n_samples)
60
60
  raise ArgumentError,
61
61
  'The value of n_splits must be not less than 1 and not more than the number of samples.'
62
62
  end
63
63
  unless n_test_samples.between?(1, n_samples)
64
64
  raise RangeError,
65
- 'The number of sample in test split must be not less than 1 and not more than the number of samples.'
65
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples.'
66
66
  end
67
67
  unless n_train_samples.between?(1, n_samples)
68
68
  raise RangeError,
69
- 'The number of sample in train split must be not less than 1 and not more than the number of samples.'
69
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples.'
70
70
  end
71
71
  if (n_test_samples + n_train_samples) > n_samples
72
72
  raise RangeError,
@@ -74,7 +74,7 @@ module Rumale
74
74
  end
75
75
  sub_rng = @rng.dup
76
76
  # Returns array consisting of the training and testing ids for each fold.
77
- dataset_ids = [*0...n_samples]
77
+ dataset_ids = Array(0...n_samples)
78
78
  Array.new(@n_splits) do
79
79
  test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
80
80
  train_ids = if @train_size.nil?
@@ -30,7 +30,7 @@ module Rumale
30
30
  # @return [Random]
31
31
  attr_reader :rng
32
32
 
33
- # Create a new data splitter for K-fold cross validation.
33
+ # Create a new data splitter for stratified K-fold cross validation.
34
34
  #
35
35
  # @param n_splits [Integer] The number of folds.
36
36
  # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
@@ -66,15 +66,15 @@ module Rumale
66
66
  raise ArgumentError,
67
67
  'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
68
68
  end
69
- unless enough_data_size_each_class?(y, @test_size)
69
+ unless enough_data_size_each_class?(y, @test_size, 'test')
70
70
  raise RangeError,
71
- 'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
71
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
72
72
  end
73
- unless enough_data_size_each_class?(y, train_sz)
73
+ unless enough_data_size_each_class?(y, train_sz, 'train')
74
74
  raise RangeError,
75
- 'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
75
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
76
76
  end
77
- unless enough_data_size_each_class?(y, train_sz + @test_size)
77
+ unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
78
78
  raise RangeError,
79
79
  'The total number of samples in test split and train split must be not more than the number of samples in each class.'
80
80
  end
@@ -85,12 +85,12 @@ module Rumale
85
85
  test_ids = []
86
86
  sample_ids_each_class.each do |sample_ids|
87
87
  n_samples = sample_ids.size
88
- n_test_samples = (@test_size * n_samples).to_i
89
- n_train_samples = (train_sz * n_samples).to_i
88
+ n_test_samples = (@test_size * n_samples).ceil.to_i
90
89
  test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
91
90
  train_ids += if @train_size.nil?
92
91
  sample_ids - test_ids
93
92
  else
93
+ n_train_samples = (train_sz * n_samples).floor.to_i
94
94
  (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
95
95
  end
96
96
  end
@@ -104,9 +104,13 @@ module Rumale
104
104
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
105
105
  end
106
106
 
107
- def enough_data_size_each_class?(y, data_size)
107
+ def enough_data_size_each_class?(y, data_size, data_type)
108
108
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
109
- (data_size * n_samples).to_i.between?(1, n_samples)
109
+ if data_type == 'test'
110
+ (data_size * n_samples).ceil.to_i.between?(1, n_samples)
111
+ else
112
+ (data_size * n_samples).floor.to_i.between?(1, n_samples)
113
+ end
110
114
  end
111
115
  end
112
116
  end