rumale 0.20.0 → 0.20.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 358515f8785eb3de2e6571a957ca76cece6b774bb022c1a0951c92d44ab422b4
4
- data.tar.gz: '0289b7eb382cd3300845412af0fd43626f4f827bb719083c879b574e3ab37eb0'
3
+ metadata.gz: 0f361026cd2922a2d36846a817eee855bf0c000156ed6c756bca29d2e42d67a2
4
+ data.tar.gz: 016fa40aa2546824cacbc32353263cbfc9427f0ceabb7e703f99854914bb9a2e
5
5
  SHA512:
6
- metadata.gz: f03fc0f27f99ed4acea3fb7d7bf34017c1dbf923b20dabc9a78d6d44f0b151bc9dc78ba24d122f81607a43fd1852e398a603b75b87656a2f79109f87c0db0d98
7
- data.tar.gz: 69f6b8892f6bfb4c43706513245c3fba687dcb6a347c1c5185a70d5e45a024b2848a019bfae48726e1f49212878e8d6d67c811ec5f4a990fdbb3a2841efdfe9b
6
+ metadata.gz: 7a53a958db7ec8b56236018505370b9908ae81a9afc9d7c8ff0b16d83971539c1ad729b5ab350eb49ae9b90ada43a8912ed2404a37eef97a4d34dad90b1d3e9f
7
+ data.tar.gz: 2f2b3d48625c7120464179bc7759c01ba7de85cb0d54720665eaf1e4822f24c1870474ebc24a47cff123e44a8626b0e0fac6a7e81216c057286071770ea5ba79
@@ -3,6 +3,7 @@ require:
3
3
  - rubocop-rspec
4
4
 
5
5
  AllCops:
6
+ NewCops: enable
6
7
  TargetRubyVersion: 2.5
7
8
  DisplayCopNames: true
8
9
  DisplayStyleGuide: true
@@ -15,34 +16,13 @@ AllCops:
15
16
  Style/Documentation:
16
17
  Enabled: false
17
18
 
18
- Style/HashEachMethods:
19
- Enabled: true
20
-
21
- Style/HashTransformKeys:
22
- Enabled: true
23
-
24
- Style/HashTransformValues:
25
- Enabled: true
26
-
27
- Lint/DeprecatedOpenSSLConstant:
28
- Enabled: true
29
-
30
- Lint/DuplicateElsifCondition:
31
- Enabled: true
32
-
33
- Lint/MixedRegexpCaptureTypes:
34
- Enabled: true
35
-
36
- Lint/RaiseException:
37
- Enabled: true
38
-
39
- Lint/StructNewOverride:
40
- Enabled: true
41
-
42
19
  Layout/LineLength:
43
20
  Max: 145
44
21
  IgnoredPatterns: ['(\A|\s)#']
45
22
 
23
+ Lint/MissingSuper:
24
+ Enabled: false
25
+
46
26
  Metrics/ModuleLength:
47
27
  Max: 200
48
28
 
@@ -78,83 +58,14 @@ Naming/MethodParameterName:
78
58
  Naming/ConstantName:
79
59
  Enabled: false
80
60
 
81
- Style/AccessorGrouping:
82
- Enabled: true
83
-
84
- Style/ArrayCoercion:
85
- Enabled: true
86
-
87
- Style/BisectedAttrAccessor:
88
- Enabled: true
89
-
90
- Style/CaseLikeIf:
91
- Enabled: true
92
-
93
- Style/ExponentialNotation:
94
- Enabled: true
95
-
96
61
  Style/FormatStringToken:
97
62
  Enabled: false
98
63
 
99
- Style/HashAsLastArrayItem:
100
- Enabled: true
101
-
102
- Style/HashLikeCase:
103
- Enabled: true
104
-
105
64
  Style/NumericLiterals:
106
65
  Enabled: false
107
66
 
108
- Style/RedundantAssignment:
109
- Enabled: true
110
-
111
- Style/RedundantFetchBlock:
112
- Enabled: true
113
-
114
- Style/RedundantFileExtensionInRequire:
115
- Enabled: true
116
-
117
- Style/RedundantRegexpCharacterClass:
118
- Enabled: true
119
-
120
- Style/RedundantRegexpEscape:
121
- Enabled: true
122
-
123
- Style/SlicingWithRange:
124
- Enabled: true
125
-
126
- Layout/EmptyLineAfterGuardClause:
127
- Enabled: true
128
-
129
- Layout/EmptyLinesAroundAttributeAccessor:
130
- Enabled: true
131
-
132
- Layout/SpaceAroundMethodCallOperator:
133
- Enabled: true
134
-
135
- Performance/AncestorsInclude:
136
- Enabled: true
137
-
138
- Performance/BigDecimalWithNumericArgument:
139
- Enabled: true
140
-
141
- Performance/RedundantSortBlock:
142
- Enabled: true
143
-
144
- Performance/RedundantStringChars:
145
- Enabled: true
146
-
147
- Performance/ReverseFirst:
148
- Enabled: true
149
-
150
- Performance/SortReverse:
151
- Enabled: true
152
-
153
- Performance/Squeeze:
154
- Enabled: true
155
-
156
- Performance/StringInclude:
157
- Enabled: true
67
+ Style/StringConcatenation:
68
+ Enabled: false
158
69
 
159
70
  RSpec/MultipleExpectations:
160
71
  Enabled: false
@@ -1,3 +1,12 @@
1
+ # 0.20.1
2
+ - Add cross-validator classes that split data according to group labels.
3
+ - [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
4
+ - [GroupShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupShuffleSplit.html)
5
+ - Fix the rounding of fractional sample counts in the shuffle-split cross-validator classes (test size is now rounded up, train size rounded down).
6
+ - [ShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/ShuffleSplit.html)
7
+ - [StratifiedShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/StratifiedShuffleSplit.html)
8
+ - Refactor some code with RuboCop.
9
+
1
10
  # 0.20.0
2
11
  ## Breaking changes
3
12
  - Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
@@ -98,8 +98,10 @@ require 'rumale/preprocessing/ordinal_encoder'
98
98
  require 'rumale/preprocessing/binarizer'
99
99
  require 'rumale/preprocessing/polynomial_features'
100
100
  require 'rumale/model_selection/k_fold'
101
+ require 'rumale/model_selection/group_k_fold'
101
102
  require 'rumale/model_selection/stratified_k_fold'
102
103
  require 'rumale/model_selection/shuffle_split'
104
+ require 'rumale/model_selection/group_shuffle_split'
103
105
  require 'rumale/model_selection/stratified_shuffle_split'
104
106
  require 'rumale/model_selection/cross_validation'
105
107
  require 'rumale/model_selection/grid_search_cv'
@@ -136,7 +136,7 @@ module Rumale
136
136
  res
137
137
  end
138
138
 
139
- # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
139
+ # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
140
140
  def condense_tree(hierarchy, min_cluster_size)
141
141
  n_edges = hierarchy.size
142
142
  root = 2 * n_edges
@@ -265,7 +265,7 @@ module Rumale
265
265
  end
266
266
  res
267
267
  end
268
- # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
268
+ # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
269
269
  end
270
270
  end
271
271
  end
@@ -225,7 +225,7 @@ module Rumale
225
225
  line = dump_label(label, label_type.to_s)
226
226
  ftvec.to_a.each_with_index do |val, n|
227
227
  idx = n + (zero_based == false ? 1 : 0)
228
- line += format(" %d:#{value_type}", idx, val) if val != 0.0
228
+ line += format(" %d:#{value_type}", idx, val) if val != 0
229
229
  end
230
230
  line
231
231
  end
@@ -77,7 +77,7 @@ module Rumale
77
77
  # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
78
78
  def transform(x)
79
79
  x = check_convert_sample_array(x)
80
- partial_fit(x, false)
80
+ partial_fit(x, update_comps: false)
81
81
  end
82
82
 
83
83
  # Inverse transform the given transformed data with the learned model.
@@ -91,7 +91,7 @@ module Rumale
91
91
 
92
92
  private
93
93
 
94
- def partial_fit(x, update_comps = true)
94
+ def partial_fit(x, update_comps: true)
95
95
  # initialize some variables.
96
96
  n_samples, n_features = x.shape
97
97
  scale = Math.sqrt(x.mean / @params[:n_components])
@@ -85,7 +85,7 @@ module Rumale
85
85
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
86
86
  # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
87
87
  # @return [RandomForestClassifier] The learned classifier itself.
88
- def fit(x, y)
88
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
89
89
  x = check_convert_sample_array(x)
90
90
  y = check_convert_label_array(y)
91
91
  check_sample_label_size(x, y)
@@ -79,7 +79,7 @@ module Rumale
79
79
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
80
80
  # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
81
81
  # @return [RandomForestRegressor] The learned regressor itself.
82
- def fit(x, y)
82
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
83
83
  x = check_convert_sample_array(x)
84
84
  y = check_convert_tvalue_array(y)
85
85
  check_sample_tvalue_size(x, y)
@@ -67,7 +67,7 @@ module Rumale
67
67
  def transform(x)
68
68
  raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
69
69
 
70
- x = [x] unless x.is_a?(Array)
70
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
71
71
  n_samples = x.size
72
72
 
73
73
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -99,7 +99,7 @@ module Rumale
99
99
  # @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
100
100
  # @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
101
101
  def transform(x)
102
- x = [x] unless x.is_a?(Array)
102
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
103
103
  n_samples = x.size
104
104
  n_features = @vocabulary.size
105
105
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -102,7 +102,7 @@ module Rumale
102
102
  break if terminate?(hi_prob_mat, lo_prob_mat)
103
103
 
104
104
  a = hi_prob_mat * lo_prob_mat
105
- b = lo_prob_mat * lo_prob_mat
105
+ b = lo_prob_mat**2
106
106
  y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
107
107
  lo_prob_mat = t_distributed_probability_matrix(y)
108
108
  @n_iter = t + 1
@@ -0,0 +1,93 @@
# frozen_string_literal: true

require 'rumale/base/splitter'
require 'rumale/preprocessing/label_encoder'

module Rumale
  module ModelSelection
    # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
    # The data points belonging to the same group are never split into different folds.
    # The number of groups should be greater than or equal to the number of splits.
    #
    # @example
    #   cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 3, 4]
    #   # [5, 6, 7]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #
    class GroupKFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Create a new data splitter for grouped K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds (at least 2, at most the number of groups given to #split).
      def initialize(n_splits: 5)
        check_params_numeric(n_splits: n_splits)
        @n_splits = n_splits
      end

      # Generate data indices for grouped K-fold cross validation.
      #
      # Groups are visited from the one with the most samples to the one with the fewest,
      # and each group is put into the fold that currently holds the fewest samples.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for grouped K-fold cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for grouped K-fold cross validation.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   @raise [ArgumentError] If n_splits is less than 2 or greater than the number of groups.
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        # Re-encode arbitrary group labels as 0...n_groups so they can be used as array indices below.
        encoder = Rumale::Preprocessing::LabelEncoder.new
        groups = encoder.fit_transform(groups)
        n_groups = encoder.classes.size

        # A single fold would leave every training set empty; zero folds cannot be built at all.
        raise ArgumentError, 'The value of n_splits must be not less than 2.' if @n_splits < 2
        raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits

        # Visit groups in descending order of size so the greedy assignment balances the folds.
        n_samples_per_group = groups.bincount
        group_ids = n_samples_per_group.sort_index.reverse
        n_samples_per_group = n_samples_per_group[group_ids]

        n_samples_per_fold = Numo::Int32.zeros(@n_splits)
        group_to_fold = Numo::Int32.zeros(n_groups)

        n_samples_per_group.each_with_index do |weight, id|
          min_sample_fold_id = n_samples_per_fold.min_index
          n_samples_per_fold[min_sample_fold_id] += weight
          group_to_fold[group_ids[id]] = min_sample_fold_id
        end

        # Fold id of every sample, obtained by looking up the fold of its (encoded) group.
        fold_ids = group_to_fold[groups]

        Array.new(@n_splits) do |fid|
          test_ids = fold_ids.eq(fid).where.to_a
          train_ids = fold_ids.ne(fid).where.to_a
          [train_ids, test_ids]
        end
      end
    end
  end
end
@@ -0,0 +1,115 @@
# frozen_string_literal: true

require 'rumale/base/splitter'

module Rumale
  module ModelSelection
    # GroupShuffleSplit is a class that generates the set of data indices
    # for random permutation cross-validation by randomly selecting group labels.
    #
    # @example
    #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #
    class GroupShuffleSplit
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation with given group labels.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of the number of groups assigned to the test data (rounded up).
      # @param train_size [Float, nil] The ratio of the number of groups assigned to the train data (rounded down).
      #   If nil, every group not selected for testing is used for training.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
        check_params_numeric(n_splits: n_splits, test_size: test_size)
        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        check_params_positive(test_size: test_size)
        check_params_positive(train_size: train_size) unless train_size.nil?
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate train and test data indices by randomly selecting group labels.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for random permutation cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for random permutation cross validation.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   @raise [RangeError] If the test/train group counts derived from the ratios are out of range.
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        labels = groups.to_a.uniq.sort
        n_groups = labels.size
        n_test = (@test_size * n_groups).ceil.to_i
        n_train = if @train_size.nil?
                    n_groups - n_test
                  else
                    (@train_size * n_groups).floor.to_i
                  end
        validate_group_counts(n_test, n_train, n_groups)

        # Work on a copy so that repeated calls to split yield the same folds.
        fold_rng = @rng.dup

        Array.new(@n_splits) do
          test_labels = labels.sample(n_test, random: fold_rng)
          remaining = labels - test_labels
          train_labels = @train_size.nil? ? remaining : remaining.sample(n_train, random: fold_rng)
          [member_indices(groups, train_labels), member_indices(groups, test_labels)]
        end
      end

      private

      # Raise RangeError unless the requested numbers of test/train groups fit into the available groups.
      def validate_group_counts(n_test, n_train, n_groups)
        unless n_test.between?(1, n_groups)
          raise RangeError,
                'The number of groups in test split must be not less than 1 and not more than the number of groups.'
        end
        unless n_train.between?(1, n_groups)
          raise RangeError,
                'The number of groups in train split must be not less than 1 and not more than the number of groups.'
        end
        return unless (n_test + n_train) > n_groups

        raise RangeError,
              'The total number of groups in test split and train split must be not more than the number of groups.'
      end

      # Return the sorted indices of the elements of +labels+ whose value is contained in +targets+.
      def member_indices(labels, targets)
        mask = targets.reduce(Numo::Bit.zeros(labels.shape[0])) { |acc, t| acc | labels.eq(t) }
        mask.where.to_a
      end
    end
  end
end
@@ -54,19 +54,19 @@ module Rumale
54
54
  x = check_convert_sample_array(x)
55
55
  # Initialize and check some variables.
56
56
  n_samples = x.shape[0]
57
- n_test_samples = (@test_size * n_samples).to_i
58
- n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
57
+ n_test_samples = (@test_size * n_samples).ceil.to_i
58
+ n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
59
59
  unless @n_splits.between?(1, n_samples)
60
60
  raise ArgumentError,
61
61
  'The value of n_splits must be not less than 1 and not more than the number of samples.'
62
62
  end
63
63
  unless n_test_samples.between?(1, n_samples)
64
64
  raise RangeError,
65
- 'The number of sample in test split must be not less than 1 and not more than the number of samples.'
65
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples.'
66
66
  end
67
67
  unless n_train_samples.between?(1, n_samples)
68
68
  raise RangeError,
69
- 'The number of sample in train split must be not less than 1 and not more than the number of samples.'
69
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples.'
70
70
  end
71
71
  if (n_test_samples + n_train_samples) > n_samples
72
72
  raise RangeError,
@@ -30,7 +30,7 @@ module Rumale
30
30
  # @return [Random]
31
31
  attr_reader :rng
32
32
 
33
- # Create a new data splitter for K-fold cross validation.
33
+ # Create a new data splitter for stratified K-fold cross validation.
34
34
  #
35
35
  # @param n_splits [Integer] The number of folds.
36
36
  # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
@@ -66,15 +66,15 @@ module Rumale
66
66
  raise ArgumentError,
67
67
  'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
68
68
  end
69
- unless enough_data_size_each_class?(y, @test_size)
69
+ unless enough_data_size_each_class?(y, @test_size, 'test')
70
70
  raise RangeError,
71
- 'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
71
+ 'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
72
72
  end
73
- unless enough_data_size_each_class?(y, train_sz)
73
+ unless enough_data_size_each_class?(y, train_sz, 'train')
74
74
  raise RangeError,
75
- 'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
75
+ 'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
76
76
  end
77
- unless enough_data_size_each_class?(y, train_sz + @test_size)
77
+ unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
78
78
  raise RangeError,
79
79
  'The total number of samples in test split and train split must be not more than the number of samples in each class.'
80
80
  end
@@ -85,12 +85,12 @@ module Rumale
85
85
  test_ids = []
86
86
  sample_ids_each_class.each do |sample_ids|
87
87
  n_samples = sample_ids.size
88
- n_test_samples = (@test_size * n_samples).to_i
89
- n_train_samples = (train_sz * n_samples).to_i
88
+ n_test_samples = (@test_size * n_samples).ceil.to_i
90
89
  test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
91
90
  train_ids += if @train_size.nil?
92
91
  sample_ids - test_ids
93
92
  else
93
+ n_train_samples = (train_sz * n_samples).floor.to_i
94
94
  (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
95
95
  end
96
96
  end
@@ -104,9 +104,13 @@ module Rumale
104
104
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
105
105
  end
106
106
 
107
- def enough_data_size_each_class?(y, data_size)
107
+ def enough_data_size_each_class?(y, data_size, data_type)
108
108
  y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
109
- (data_size * n_samples).to_i.between?(1, n_samples)
109
+ if data_type == 'test'
110
+ (data_size * n_samples).ceil.to_i.between?(1, n_samples)
111
+ else
112
+ (data_size * n_samples).floor.to_i.between?(1, n_samples)
113
+ end
110
114
  end
111
115
  end
112
116
  end
@@ -98,7 +98,7 @@ module Rumale
98
98
 
99
99
  def hessian_matrix(probs, df, sigma)
100
100
  sub = probs * (1 - probs)
101
- h11 = (df * df * sub).sum + sigma
101
+ h11 = (df**2 * sub).sum + sigma
102
102
  h22 = sub.sum + sigma
103
103
  h21 = (df * sub).sum
104
104
  Numo::DFloat[[h11, h21], [h21, h22]]
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.20.0'
6
+ VERSION = '0.20.1'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.20.0
4
+ version: 0.20.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-08-01 00:00:00.000000000 Z
11
+ date: 2020-08-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -135,6 +135,8 @@ files:
135
135
  - lib/rumale/model_selection/cross_validation.rb
136
136
  - lib/rumale/model_selection/function.rb
137
137
  - lib/rumale/model_selection/grid_search_cv.rb
138
+ - lib/rumale/model_selection/group_k_fold.rb
139
+ - lib/rumale/model_selection/group_shuffle_split.rb
138
140
  - lib/rumale/model_selection/k_fold.rb
139
141
  - lib/rumale/model_selection/shuffle_split.rb
140
142
  - lib/rumale/model_selection/stratified_k_fold.rb