rumale 0.20.0 → 0.20.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 358515f8785eb3de2e6571a957ca76cece6b774bb022c1a0951c92d44ab422b4
4
- data.tar.gz: '0289b7eb382cd3300845412af0fd43626f4f827bb719083c879b574e3ab37eb0'
3
+ metadata.gz: 0f361026cd2922a2d36846a817eee855bf0c000156ed6c756bca29d2e42d67a2
4
+ data.tar.gz: 016fa40aa2546824cacbc32353263cbfc9427f0ceabb7e703f99854914bb9a2e
5
5
  SHA512:
6
- metadata.gz: f03fc0f27f99ed4acea3fb7d7bf34017c1dbf923b20dabc9a78d6d44f0b151bc9dc78ba24d122f81607a43fd1852e398a603b75b87656a2f79109f87c0db0d98
7
- data.tar.gz: 69f6b8892f6bfb4c43706513245c3fba687dcb6a347c1c5185a70d5e45a024b2848a019bfae48726e1f49212878e8d6d67c811ec5f4a990fdbb3a2841efdfe9b
6
+ metadata.gz: 7a53a958db7ec8b56236018505370b9908ae81a9afc9d7c8ff0b16d83971539c1ad729b5ab350eb49ae9b90ada43a8912ed2404a37eef97a4d34dad90b1d3e9f
7
+ data.tar.gz: 2f2b3d48625c7120464179bc7759c01ba7de85cb0d54720665eaf1e4822f24c1870474ebc24a47cff123e44a8626b0e0fac6a7e81216c057286071770ea5ba79
@@ -3,6 +3,7 @@ require:
3
3
  - rubocop-rspec
4
4
 
5
5
  AllCops:
6
+ NewCops: enable
6
7
  TargetRubyVersion: 2.5
7
8
  DisplayCopNames: true
8
9
  DisplayStyleGuide: true
@@ -15,34 +16,13 @@ AllCops:
15
16
  Style/Documentation:
16
17
  Enabled: false
17
18
 
18
- Style/HashEachMethods:
19
- Enabled: true
20
-
21
- Style/HashTransformKeys:
22
- Enabled: true
23
-
24
- Style/HashTransformValues:
25
- Enabled: true
26
-
27
- Lint/DeprecatedOpenSSLConstant:
28
- Enabled: true
29
-
30
- Lint/DuplicateElsifCondition:
31
- Enabled: true
32
-
33
- Lint/MixedRegexpCaptureTypes:
34
- Enabled: true
35
-
36
- Lint/RaiseException:
37
- Enabled: true
38
-
39
- Lint/StructNewOverride:
40
- Enabled: true
41
-
42
19
  Layout/LineLength:
43
20
  Max: 145
44
21
  IgnoredPatterns: ['(\A|\s)#']
45
22
 
23
+ Lint/MissingSuper:
24
+ Enabled: false
25
+
46
26
  Metrics/ModuleLength:
47
27
  Max: 200
48
28
 
@@ -78,83 +58,14 @@ Naming/MethodParameterName:
78
58
  Naming/ConstantName:
79
59
  Enabled: false
80
60
 
81
- Style/AccessorGrouping:
82
- Enabled: true
83
-
84
- Style/ArrayCoercion:
85
- Enabled: true
86
-
87
- Style/BisectedAttrAccessor:
88
- Enabled: true
89
-
90
- Style/CaseLikeIf:
91
- Enabled: true
92
-
93
- Style/ExponentialNotation:
94
- Enabled: true
95
-
96
61
  Style/FormatStringToken:
97
62
  Enabled: false
98
63
 
99
- Style/HashAsLastArrayItem:
100
- Enabled: true
101
-
102
- Style/HashLikeCase:
103
- Enabled: true
104
-
105
64
  Style/NumericLiterals:
106
65
  Enabled: false
107
66
 
108
- Style/RedundantAssignment:
109
- Enabled: true
110
-
111
- Style/RedundantFetchBlock:
112
- Enabled: true
113
-
114
- Style/RedundantFileExtensionInRequire:
115
- Enabled: true
116
-
117
- Style/RedundantRegexpCharacterClass:
118
- Enabled: true
119
-
120
- Style/RedundantRegexpEscape:
121
- Enabled: true
122
-
123
- Style/SlicingWithRange:
124
- Enabled: true
125
-
126
- Layout/EmptyLineAfterGuardClause:
127
- Enabled: true
128
-
129
- Layout/EmptyLinesAroundAttributeAccessor:
130
- Enabled: true
131
-
132
- Layout/SpaceAroundMethodCallOperator:
133
- Enabled: true
134
-
135
- Performance/AncestorsInclude:
136
- Enabled: true
137
-
138
- Performance/BigDecimalWithNumericArgument:
139
- Enabled: true
140
-
141
- Performance/RedundantSortBlock:
142
- Enabled: true
143
-
144
- Performance/RedundantStringChars:
145
- Enabled: true
146
-
147
- Performance/ReverseFirst:
148
- Enabled: true
149
-
150
- Performance/SortReverse:
151
- Enabled: true
152
-
153
- Performance/Squeeze:
154
- Enabled: true
155
-
156
- Performance/StringInclude:
157
- Enabled: true
67
+ Style/StringConcatenation:
68
+ Enabled: false
158
69
 
159
70
  RSpec/MultipleExpectations:
160
71
  Enabled: false
@@ -1,3 +1,12 @@
1
+ # 0.20.1
2
+ - Add cross-validator classes that split data according to group labels.
3
+ - [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
4
+ - [GroupShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupShuffleSplit.html)
5
+ - Fix the rounding of fractional sample counts in the shuffle split cross-validator classes.
6
+ - [ShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/ShuffleSplit.html)
7
+ - [StratifiedShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/StratifiedShuffleSplit.html)
8
+ - Refactor some code with RuboCop.
9
+
1
10
  # 0.20.0
2
11
  ## Breaking changes
3
12
  - Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
@@ -98,8 +98,10 @@ require 'rumale/preprocessing/ordinal_encoder'
98
98
  require 'rumale/preprocessing/binarizer'
99
99
  require 'rumale/preprocessing/polynomial_features'
100
100
  require 'rumale/model_selection/k_fold'
101
+ require 'rumale/model_selection/group_k_fold'
101
102
  require 'rumale/model_selection/stratified_k_fold'
102
103
  require 'rumale/model_selection/shuffle_split'
104
+ require 'rumale/model_selection/group_shuffle_split'
103
105
  require 'rumale/model_selection/stratified_shuffle_split'
104
106
  require 'rumale/model_selection/cross_validation'
105
107
  require 'rumale/model_selection/grid_search_cv'
@@ -136,7 +136,7 @@ module Rumale
136
136
  res
137
137
  end
138
138
 
139
- # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
139
+ # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
140
140
  def condense_tree(hierarchy, min_cluster_size)
141
141
  n_edges = hierarchy.size
142
142
  root = 2 * n_edges
@@ -265,7 +265,7 @@ module Rumale
265
265
  end
266
266
  res
267
267
  end
268
- # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
268
+ # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
269
269
  end
270
270
  end
271
271
  end
@@ -225,7 +225,7 @@ module Rumale
225
225
  line = dump_label(label, label_type.to_s)
226
226
  ftvec.to_a.each_with_index do |val, n|
227
227
  idx = n + (zero_based == false ? 1 : 0)
228
- line += format(" %d:#{value_type}", idx, val) if val != 0.0
228
+ line += format(" %d:#{value_type}", idx, val) if val != 0
229
229
  end
230
230
  line
231
231
  end
@@ -77,7 +77,7 @@ module Rumale
77
77
  # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
78
78
  def transform(x)
79
79
  x = check_convert_sample_array(x)
80
- partial_fit(x, false)
80
+ partial_fit(x, update_comps: false)
81
81
  end
82
82
 
83
83
  # Inverse transform the given transformed data with the learned model.
@@ -91,7 +91,7 @@ module Rumale
91
91
 
92
92
  private
93
93
 
94
- def partial_fit(x, update_comps = true)
94
+ def partial_fit(x, update_comps: true)
95
95
  # initialize some variables.
96
96
  n_samples, n_features = x.shape
97
97
  scale = Math.sqrt(x.mean / @params[:n_components])
@@ -85,7 +85,7 @@ module Rumale
85
85
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
86
86
  # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
87
87
  # @return [RandomForestClassifier] The learned classifier itself.
88
- def fit(x, y)
88
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
89
89
  x = check_convert_sample_array(x)
90
90
  y = check_convert_label_array(y)
91
91
  check_sample_label_size(x, y)
@@ -79,7 +79,7 @@ module Rumale
79
79
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
80
80
  # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
81
81
  # @return [RandomForestRegressor] The learned regressor itself.
82
- def fit(x, y)
82
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
83
83
  x = check_convert_sample_array(x)
84
84
  y = check_convert_tvalue_array(y)
85
85
  check_sample_tvalue_size(x, y)
@@ -67,7 +67,7 @@ module Rumale
67
67
  def transform(x)
68
68
  raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
69
69
 
70
- x = [x] unless x.is_a?(Array)
70
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
71
71
  n_samples = x.size
72
72
 
73
73
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -99,7 +99,7 @@ module Rumale
99
99
  # @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
100
100
  # @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
101
101
  def transform(x)
102
- x = [x] unless x.is_a?(Array)
102
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
103
103
  n_samples = x.size
104
104
  n_features = @vocabulary.size
105
105
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -102,7 +102,7 @@ module Rumale
102
102
  break if terminate?(hi_prob_mat, lo_prob_mat)
103
103
 
104
104
  a = hi_prob_mat * lo_prob_mat
105
- b = lo_prob_mat * lo_prob_mat
105
+ b = lo_prob_mat**2
106
106
  y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
107
107
  lo_prob_mat = t_distributed_probability_matrix(y)
108
108
  @n_iter = t + 1
@@ -0,0 +1,93 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+ require 'rumale/preprocessing/label_encoder'
5
+
6
+ module Rumale
7
+ module ModelSelection
8
+ # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
9
+ # Data points belonging to the same group are never split across different folds.
10
+ # The number of groups should be greater than or equal to the number of splits.
11
+ #
12
+ # @example
13
+ # cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
14
+ # x = Numo::DFloat.new(8, 2).rand
15
+ # groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
16
+ # cv.split(x, nil, groups).each do |train_ids, test_ids|
17
+ # puts '---'
18
+ # pp train_ids
19
+ # pp test_ids
20
+ # end
21
+ #
22
+ # # ---
23
+ # # [0, 1, 2, 3, 4]
24
+ # # [5, 6, 7]
25
+ # # ---
26
+ # # [3, 4, 5, 6, 7]
27
+ # # [0, 1, 2]
28
+ # # ---
29
+ # # [0, 1, 2, 5, 6, 7]
30
+ # # [3, 4]
31
+ #
32
+ class GroupKFold
33
+ include Base::Splitter
34
+
35
+ # Return the number of folds.
36
+ # @return [Integer]
37
+ attr_reader :n_splits
38
+
39
+ # Create a new data splitter for grouped K-fold cross validation.
40
+ #
41
+ # @param n_splits [Integer] The number of folds.
42
+ def initialize(n_splits: 5)
43
+ check_params_numeric(n_splits: n_splits)
44
+ @n_splits = n_splits
45
+ end
46
+
47
+ # Generate data indices for grouped K-fold cross validation.
48
+ #
49
+ # @overload split(x, y, groups) -> Array
50
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features])
51
+ # The dataset to be used to generate data indices for grouped K-fold cross validation.
52
+ # @param y [Numo::Int32] (shape: [n_samples])
53
+ # This argument exists to unify the interface between the K-fold methods, it is not used in the method.
54
+ # @param groups [Numo::Int32] (shape: [n_samples])
55
+ # The group labels to be used to generate data indices for grouped K-fold cross validation.
56
+ # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
57
+ def split(x, _y, groups)
58
+ x = check_convert_sample_array(x)
59
+ groups = check_convert_label_array(groups)
60
+ check_sample_label_size(x, groups)
61
+
62
+ encoder = Rumale::Preprocessing::LabelEncoder.new
63
+ groups = encoder.fit_transform(groups)
64
+ n_groups = encoder.classes.size
65
+
66
+ raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits
67
+
68
+ n_samples_per_group = groups.bincount
69
+ group_ids = n_samples_per_group.sort_index.reverse
70
+ n_samples_per_group = n_samples_per_group[group_ids]
71
+
72
+ n_samples_per_fold = Numo::Int32.zeros(@n_splits)
73
+ group_to_fold = Numo::Int32.zeros(n_groups)
74
+
75
+ n_samples_per_group.each_with_index do |weight, id|
76
+ min_sample_fold_id = n_samples_per_fold.min_index
77
+ n_samples_per_fold[min_sample_fold_id] += weight
78
+ group_to_fold[group_ids[id]] = min_sample_fold_id
79
+ end
80
+
81
+ n_samples = x.shape[0]
82
+ sample_ids = Array(0...n_samples)
83
+ fold_ids = group_to_fold[groups]
84
+
85
+ Array.new(@n_splits) do |fid|
86
+ test_ids = fold_ids.eq(fid).where.to_a
87
+ train_ids = sample_ids - test_ids
88
+ [train_ids, test_ids]
89
+ end
90
+ end
91
+ end
92
+ end
93
+ end
@@ -0,0 +1,115 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
+ module Rumale
6
+ module ModelSelection
7
+ # GroupShuffleSplit is a class that generates the set of data indices
8
+ # for random permutation cross-validation by randomly selecting group labels.
9
+ #
10
+ # @example
11
+ # cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
12
+ # x = Numo::DFloat.new(8, 2).rand
13
+ # groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
14
+ # cv.split(x, nil, groups).each do |train_ids, test_ids|
15
+ # puts '---'
16
+ # pp train_ids
17
+ # pp test_ids
18
+ # end
19
+ #
20
+ # # ---
21
+ # # [0, 1, 2, 5, 6, 7]
22
+ # # [3, 4]
23
+ # # ---
24
+ # # [3, 4, 5, 6, 7]
25
+ # # [0, 1, 2]
26
+ #
27
+ class GroupShuffleSplit
28
+ include Base::Splitter
29
+
30
+ # Return the number of folds.
31
+ # @return [Integer]
32
+ attr_reader :n_splits
33
+
34
+ # Return the random generator for shuffling the dataset.
35
+ # @return [Random]
36
+ attr_reader :rng
37
+
38
+ # Create a new data splitter for random permutation cross validation with given group labels.
39
+ #
40
+ # @param n_splits [Integer] The number of folds.
41
+ # @param test_size [Float] The ratio of number of groups for test data.
42
+ # @param train_size [Float/Nil] The ratio of number of groups for train data.
43
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
44
+ def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
45
+ check_params_numeric(n_splits: n_splits, test_size: test_size)
46
+ check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
47
+ check_params_positive(n_splits: n_splits)
48
+ check_params_positive(test_size: test_size)
49
+ check_params_positive(train_size: train_size) unless train_size.nil?
50
+ @n_splits = n_splits
51
+ @test_size = test_size
52
+ @train_size = train_size
53
+ @random_seed = random_seed
54
+ @random_seed ||= srand
55
+ @rng = Random.new(@random_seed)
56
+ end
57
+
58
+ # Generate train and test data indices by randomly selecting group labels.
59
+ #
60
+ # @overload split(x, y, groups) -> Array
61
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features])
62
+ # The dataset to be used to generate data indices for random permutation cross validation.
63
+ # @param y [Numo::Int32] (shape: [n_samples])
64
+ # This argument exists to unify the interface between the K-fold methods, it is not used in the method.
65
+ # @param groups [Numo::Int32] (shape: [n_samples])
66
+ # The group labels to be used to generate data indices for random permutation cross validation.
67
+ # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
68
+ def split(x, _y, groups)
69
+ x = check_convert_sample_array(x)
70
+ groups = check_convert_label_array(groups)
71
+ check_sample_label_size(x, groups)
72
+
73
+ classes = groups.to_a.uniq.sort
74
+ n_groups = classes.size
75
+ n_test_groups = (@test_size * n_groups).ceil.to_i
76
+ n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i
77
+
78
+ unless n_test_groups.between?(1, n_groups)
79
+ raise RangeError,
80
+ 'The number of groups in test split must be not less than 1 and not more than the number of groups.'
81
+ end
82
+ unless n_train_groups.between?(1, n_groups)
83
+ raise RangeError,
84
+ 'The number of groups in train split must be not less than 1 and not more than the number of groups.'
85
+ end
86
+ if (n_test_groups + n_train_groups) > n_groups
87
+ raise RangeError,
88
+ 'The total number of groups in test split and train split must be not more than the number of groups.'
89
+ end
90
+
91
+ sub_rng = @rng.dup
92
+
93
+ Array.new(@n_splits) do
94
+ test_group_ids = classes.sample(n_test_groups, random: sub_rng)
95
+ train_group_ids = if @train_size.nil?
96
+ classes - test_group_ids
97
+ else
98
+ (classes - test_group_ids).sample(n_train_groups, random: sub_rng)
99
+ end
100
+ test_ids = in1d(groups, test_group_ids).where.to_a
101
+ train_ids = in1d(groups, train_group_ids).where.to_a
102
+ [train_ids, test_ids]
103
+ end
104
+ end
105
+
106
+ private
107
+
108
+ def in1d(a, b)
109
+ res = Numo::Bit.zeros(a.shape[0])
110
+ b.each { |v| res |= a.eq(v) }
111
+ res
112
+ end
113
+ end
114
+ end
115
+ end
@@ -54,19 +54,19 @@ module Rumale
54
54
  x = check_convert_sample_array(x)
55
55
  # Initialize and check some variables.
56
56
  n_samples = x.shape[0]
57
- n_test_samples = (@test_size * n_samples).to_i
58
- n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
57
+ n_test_samples = (@test_size * n_samples).ceil.to_i
58
+ n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
59
59
  unless @n_splits.between?(1, n_samples)
60
60
  raise ArgumentError,
61
61
  'The value of n_splits must be not less than 1 and not more than the number of samples.'
62
62
  end
63
63
  unless n_test_samples.between?(1, n_samples)
64
64
  raise RangeError,
65
- 'The number of sample in test split must be not less than 1 and not more than the number of samples.'
65
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples.'
66
66
  end
67
67
  unless n_train_samples.between?(1, n_samples)
68
68
  raise RangeError,
69
- 'The number of sample in train split must be not less than 1 and not more than the number of samples.'
69
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples.'
70
70
  end
71
71
  if (n_test_samples + n_train_samples) > n_samples
72
72
  raise RangeError,
@@ -30,7 +30,7 @@ module Rumale
30
30
  # @return [Random]
31
31
  attr_reader :rng
32
32
 
33
- # Create a new data splitter for K-fold cross validation.
33
+ # Create a new data splitter for stratified K-fold cross validation.
34
34
  #
35
35
  # @param n_splits [Integer] The number of folds.
36
36
  # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
@@ -66,15 +66,15 @@ module Rumale
66
66
  raise ArgumentError,
67
67
  'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
68
68
  end
69
- unless enough_data_size_each_class?(y, @test_size)
69
+ unless enough_data_size_each_class?(y, @test_size, 'test')
70
70
  raise RangeError,
71
- 'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
71
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
72
72
  end
73
- unless enough_data_size_each_class?(y, train_sz)
73
+ unless enough_data_size_each_class?(y, train_sz, 'train')
74
74
  raise RangeError,
75
- 'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
75
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
76
76
  end
77
- unless enough_data_size_each_class?(y, train_sz + @test_size)
77
+ unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
78
78
  raise RangeError,
79
79
  'The total number of samples in test split and train split must be not more than the number of samples in each class.'
80
80
  end
@@ -85,12 +85,12 @@ module Rumale
85
85
  test_ids = []
86
86
  sample_ids_each_class.each do |sample_ids|
87
87
  n_samples = sample_ids.size
88
- n_test_samples = (@test_size * n_samples).to_i
89
- n_train_samples = (train_sz * n_samples).to_i
88
+ n_test_samples = (@test_size * n_samples).ceil.to_i
90
89
  test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
91
90
  train_ids += if @train_size.nil?
92
91
  sample_ids - test_ids
93
92
  else
93
+ n_train_samples = (train_sz * n_samples).floor.to_i
94
94
  (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
95
95
  end
96
96
  end
@@ -104,9 +104,13 @@ module Rumale
104
104
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
105
105
  end
106
106
 
107
- def enough_data_size_each_class?(y, data_size)
107
+ def enough_data_size_each_class?(y, data_size, data_type)
108
108
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
109
- (data_size * n_samples).to_i.between?(1, n_samples)
109
+ if data_type == 'test'
110
+ (data_size * n_samples).ceil.to_i.between?(1, n_samples)
111
+ else
112
+ (data_size * n_samples).floor.to_i.between?(1, n_samples)
113
+ end
110
114
  end
111
115
  end
112
116
  end
@@ -98,7 +98,7 @@ module Rumale
98
98
 
99
99
  def hessian_matrix(probs, df, sigma)
100
100
  sub = probs * (1 - probs)
101
- h11 = (df * df * sub).sum + sigma
101
+ h11 = (df**2 * sub).sum + sigma
102
102
  h22 = sub.sum + sigma
103
103
  h21 = (df * sub).sum
104
104
  Numo::DFloat[[h11, h21], [h21, h22]]
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.20.0'
6
+ VERSION = '0.20.1'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.20.0
4
+ version: 0.20.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-08-01 00:00:00.000000000 Z
11
+ date: 2020-08-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -135,6 +135,8 @@ files:
135
135
  - lib/rumale/model_selection/cross_validation.rb
136
136
  - lib/rumale/model_selection/function.rb
137
137
  - lib/rumale/model_selection/grid_search_cv.rb
138
+ - lib/rumale/model_selection/group_k_fold.rb
139
+ - lib/rumale/model_selection/group_shuffle_split.rb
138
140
  - lib/rumale/model_selection/k_fold.rb
139
141
  - lib/rumale/model_selection/shuffle_split.rb
140
142
  - lib/rumale/model_selection/stratified_k_fold.rb