rumale 0.20.0 → 0.20.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +6 -95
- data/CHANGELOG.md +9 -0
- data/lib/rumale.rb +2 -0
- data/lib/rumale/clustering/hdbscan.rb +2 -2
- data/lib/rumale/dataset.rb +1 -1
- data/lib/rumale/decomposition/nmf.rb +2 -2
- data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
- data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
- data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
- data/lib/rumale/manifold/tsne.rb +1 -1
- data/lib/rumale/model_selection/group_k_fold.rb +93 -0
- data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
- data/lib/rumale/model_selection/shuffle_split.rb +4 -4
- data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
- data/lib/rumale/probabilistic_output.rb +1 -1
- data/lib/rumale/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0f361026cd2922a2d36846a817eee855bf0c000156ed6c756bca29d2e42d67a2
|
4
|
+
data.tar.gz: 016fa40aa2546824cacbc32353263cbfc9427f0ceabb7e703f99854914bb9a2e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7a53a958db7ec8b56236018505370b9908ae81a9afc9d7c8ff0b16d83971539c1ad729b5ab350eb49ae9b90ada43a8912ed2404a37eef97a4d34dad90b1d3e9f
|
7
|
+
data.tar.gz: 2f2b3d48625c7120464179bc7759c01ba7de85cb0d54720665eaf1e4822f24c1870474ebc24a47cff123e44a8626b0e0fac6a7e81216c057286071770ea5ba79
|
data/.rubocop.yml
CHANGED
@@ -3,6 +3,7 @@ require:
|
|
3
3
|
- rubocop-rspec
|
4
4
|
|
5
5
|
AllCops:
|
6
|
+
NewCops: enable
|
6
7
|
TargetRubyVersion: 2.5
|
7
8
|
DisplayCopNames: true
|
8
9
|
DisplayStyleGuide: true
|
@@ -15,34 +16,13 @@ AllCops:
|
|
15
16
|
Style/Documentation:
|
16
17
|
Enabled: false
|
17
18
|
|
18
|
-
Style/HashEachMethods:
|
19
|
-
Enabled: true
|
20
|
-
|
21
|
-
Style/HashTransformKeys:
|
22
|
-
Enabled: true
|
23
|
-
|
24
|
-
Style/HashTransformValues:
|
25
|
-
Enabled: true
|
26
|
-
|
27
|
-
Lint/DeprecatedOpenSSLConstant:
|
28
|
-
Enabled: true
|
29
|
-
|
30
|
-
Lint/DuplicateElsifCondition:
|
31
|
-
Enabled: true
|
32
|
-
|
33
|
-
Lint/MixedRegexpCaptureTypes:
|
34
|
-
Enabled: true
|
35
|
-
|
36
|
-
Lint/RaiseException:
|
37
|
-
Enabled: true
|
38
|
-
|
39
|
-
Lint/StructNewOverride:
|
40
|
-
Enabled: true
|
41
|
-
|
42
19
|
Layout/LineLength:
|
43
20
|
Max: 145
|
44
21
|
IgnoredPatterns: ['(\A|\s)#']
|
45
22
|
|
23
|
+
Lint/MissingSuper:
|
24
|
+
Enabled: false
|
25
|
+
|
46
26
|
Metrics/ModuleLength:
|
47
27
|
Max: 200
|
48
28
|
|
@@ -78,83 +58,14 @@ Naming/MethodParameterName:
|
|
78
58
|
Naming/ConstantName:
|
79
59
|
Enabled: false
|
80
60
|
|
81
|
-
Style/AccessorGrouping:
|
82
|
-
Enabled: true
|
83
|
-
|
84
|
-
Style/ArrayCoercion:
|
85
|
-
Enabled: true
|
86
|
-
|
87
|
-
Style/BisectedAttrAccessor:
|
88
|
-
Enabled: true
|
89
|
-
|
90
|
-
Style/CaseLikeIf:
|
91
|
-
Enabled: true
|
92
|
-
|
93
|
-
Style/ExponentialNotation:
|
94
|
-
Enabled: true
|
95
|
-
|
96
61
|
Style/FormatStringToken:
|
97
62
|
Enabled: false
|
98
63
|
|
99
|
-
Style/HashAsLastArrayItem:
|
100
|
-
Enabled: true
|
101
|
-
|
102
|
-
Style/HashLikeCase:
|
103
|
-
Enabled: true
|
104
|
-
|
105
64
|
Style/NumericLiterals:
|
106
65
|
Enabled: false
|
107
66
|
|
108
|
-
Style/
|
109
|
-
Enabled:
|
110
|
-
|
111
|
-
Style/RedundantFetchBlock:
|
112
|
-
Enabled: true
|
113
|
-
|
114
|
-
Style/RedundantFileExtensionInRequire:
|
115
|
-
Enabled: true
|
116
|
-
|
117
|
-
Style/RedundantRegexpCharacterClass:
|
118
|
-
Enabled: true
|
119
|
-
|
120
|
-
Style/RedundantRegexpEscape:
|
121
|
-
Enabled: true
|
122
|
-
|
123
|
-
Style/SlicingWithRange:
|
124
|
-
Enabled: true
|
125
|
-
|
126
|
-
Layout/EmptyLineAfterGuardClause:
|
127
|
-
Enabled: true
|
128
|
-
|
129
|
-
Layout/EmptyLinesAroundAttributeAccessor:
|
130
|
-
Enabled: true
|
131
|
-
|
132
|
-
Layout/SpaceAroundMethodCallOperator:
|
133
|
-
Enabled: true
|
134
|
-
|
135
|
-
Performance/AncestorsInclude:
|
136
|
-
Enabled: true
|
137
|
-
|
138
|
-
Performance/BigDecimalWithNumericArgument:
|
139
|
-
Enabled: true
|
140
|
-
|
141
|
-
Performance/RedundantSortBlock:
|
142
|
-
Enabled: true
|
143
|
-
|
144
|
-
Performance/RedundantStringChars:
|
145
|
-
Enabled: true
|
146
|
-
|
147
|
-
Performance/ReverseFirst:
|
148
|
-
Enabled: true
|
149
|
-
|
150
|
-
Performance/SortReverse:
|
151
|
-
Enabled: true
|
152
|
-
|
153
|
-
Performance/Squeeze:
|
154
|
-
Enabled: true
|
155
|
-
|
156
|
-
Performance/StringInclude:
|
157
|
-
Enabled: true
|
67
|
+
Style/StringConcatenation:
|
68
|
+
Enabled: false
|
158
69
|
|
159
70
|
RSpec/MultipleExpectations:
|
160
71
|
Enabled: false
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
# 0.20.1
|
2
|
+
- Add cross-validator classes that split data according to group labels.
|
3
|
+
- [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
|
4
|
+
- [GroupShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupShuffleSplit.html)
|
5
|
+
- Fix the handling of fractional sample counts in the shuffle split cross-validator classes.
|
6
|
+
- [ShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/ShuffleSplit.html)
|
7
|
+
- [StratifiedShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/StratifiedShuffleSplit.html)
|
8
|
+
- Refactor some code with RuboCop.
|
9
|
+
|
1
10
|
# 0.20.0
|
2
11
|
## Breaking changes
|
3
12
|
- Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
|
data/lib/rumale.rb
CHANGED
@@ -98,8 +98,10 @@ require 'rumale/preprocessing/ordinal_encoder'
|
|
98
98
|
require 'rumale/preprocessing/binarizer'
|
99
99
|
require 'rumale/preprocessing/polynomial_features'
|
100
100
|
require 'rumale/model_selection/k_fold'
|
101
|
+
require 'rumale/model_selection/group_k_fold'
|
101
102
|
require 'rumale/model_selection/stratified_k_fold'
|
102
103
|
require 'rumale/model_selection/shuffle_split'
|
104
|
+
require 'rumale/model_selection/group_shuffle_split'
|
103
105
|
require 'rumale/model_selection/stratified_shuffle_split'
|
104
106
|
require 'rumale/model_selection/cross_validation'
|
105
107
|
require 'rumale/model_selection/grid_search_cv'
|
@@ -136,7 +136,7 @@ module Rumale
|
|
136
136
|
res
|
137
137
|
end
|
138
138
|
|
139
|
-
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
|
139
|
+
# rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
|
140
140
|
def condense_tree(hierarchy, min_cluster_size)
|
141
141
|
n_edges = hierarchy.size
|
142
142
|
root = 2 * n_edges
|
@@ -265,7 +265,7 @@ module Rumale
|
|
265
265
|
end
|
266
266
|
res
|
267
267
|
end
|
268
|
-
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
|
268
|
+
# rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
|
269
269
|
end
|
270
270
|
end
|
271
271
|
end
|
data/lib/rumale/dataset.rb
CHANGED
@@ -225,7 +225,7 @@ module Rumale
|
|
225
225
|
line = dump_label(label, label_type.to_s)
|
226
226
|
ftvec.to_a.each_with_index do |val, n|
|
227
227
|
idx = n + (zero_based == false ? 1 : 0)
|
228
|
-
line += format(" %d:#{value_type}", idx, val) if val != 0
|
228
|
+
line += format(" %d:#{value_type}", idx, val) if val != 0
|
229
229
|
end
|
230
230
|
line
|
231
231
|
end
|
@@ -77,7 +77,7 @@ module Rumale
|
|
77
77
|
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
|
78
78
|
def transform(x)
|
79
79
|
x = check_convert_sample_array(x)
|
80
|
-
partial_fit(x, false)
|
80
|
+
partial_fit(x, update_comps: false)
|
81
81
|
end
|
82
82
|
|
83
83
|
# Inverse transform the given transformed data with the learned model.
|
@@ -91,7 +91,7 @@ module Rumale
|
|
91
91
|
|
92
92
|
private
|
93
93
|
|
94
|
-
def partial_fit(x, update_comps = true)
|
94
|
+
def partial_fit(x, update_comps: true)
|
95
95
|
# initialize some variables.
|
96
96
|
n_samples, n_features = x.shape
|
97
97
|
scale = Math.sqrt(x.mean / @params[:n_components])
|
@@ -85,7 +85,7 @@ module Rumale
|
|
85
85
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
86
86
|
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
|
87
87
|
# @return [RandomForestClassifier] The learned classifier itself.
|
88
|
-
def fit(x, y)
|
88
|
+
def fit(x, y) # rubocop:disable Metrics/AbcSize
|
89
89
|
x = check_convert_sample_array(x)
|
90
90
|
y = check_convert_label_array(y)
|
91
91
|
check_sample_label_size(x, y)
|
@@ -79,7 +79,7 @@ module Rumale
|
|
79
79
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
80
80
|
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
|
81
81
|
# @return [RandomForestRegressor] The learned regressor itself.
|
82
|
-
def fit(x, y)
|
82
|
+
def fit(x, y) # rubocop:disable Metrics/AbcSize
|
83
83
|
x = check_convert_sample_array(x)
|
84
84
|
y = check_convert_tvalue_array(y)
|
85
85
|
check_sample_tvalue_size(x, y)
|
@@ -67,7 +67,7 @@ module Rumale
|
|
67
67
|
def transform(x)
|
68
68
|
raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
|
69
69
|
|
70
|
-
x = [x] unless x.is_a?(Array)
|
70
|
+
x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
|
71
71
|
n_samples = x.size
|
72
72
|
|
73
73
|
z = Numo::DFloat.zeros(n_samples, n_features)
|
@@ -99,7 +99,7 @@ module Rumale
|
|
99
99
|
# @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
|
100
100
|
# @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
|
101
101
|
def transform(x)
|
102
|
-
x = [x] unless x.is_a?(Array)
|
102
|
+
x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
|
103
103
|
n_samples = x.size
|
104
104
|
n_features = @vocabulary.size
|
105
105
|
z = Numo::DFloat.zeros(n_samples, n_features)
|
data/lib/rumale/manifold/tsne.rb
CHANGED
@@ -102,7 +102,7 @@ module Rumale
|
|
102
102
|
break if terminate?(hi_prob_mat, lo_prob_mat)
|
103
103
|
|
104
104
|
a = hi_prob_mat * lo_prob_mat
|
105
|
-
b = lo_prob_mat
|
105
|
+
b = lo_prob_mat**2
|
106
106
|
y = (b.dot(one_vec) * y + (a - b).dot(y)) / a.dot(one_vec)
|
107
107
|
lo_prob_mat = t_distributed_probability_matrix(y)
|
108
108
|
@n_iter = t + 1
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
require 'rumale/preprocessing/label_encoder'
|
5
|
+
|
6
|
+
module Rumale
|
7
|
+
module ModelSelection
|
8
|
+
# GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
|
9
|
+
# Data points belonging to the same group are never split across different folds.
|
10
|
+
# The number of groups should be greater than or equal to the number of splits.
|
11
|
+
#
|
12
|
+
# @example
|
13
|
+
# cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
|
14
|
+
# x = Numo::DFloat.new(8, 2).rand
|
15
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
16
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
17
|
+
# puts '---'
|
18
|
+
# pp train_ids
|
19
|
+
# pp test_ids
|
20
|
+
# end
|
21
|
+
#
|
22
|
+
# # ---
|
23
|
+
# # [0, 1, 2, 3, 4]
|
24
|
+
# # [5, 6, 7]
|
25
|
+
# # ---
|
26
|
+
# # [3, 4, 5, 6, 7]
|
27
|
+
# # [0, 1, 2]
|
28
|
+
# # ---
|
29
|
+
# # [0, 1, 2, 5, 6, 7]
|
30
|
+
# # [3, 4]
|
31
|
+
#
|
32
|
+
class GroupKFold
|
33
|
+
include Base::Splitter
|
34
|
+
|
35
|
+
# Return the number of folds.
|
36
|
+
# @return [Integer]
|
37
|
+
attr_reader :n_splits
|
38
|
+
|
39
|
+
# Create a new data splitter for grouped K-fold cross validation.
|
40
|
+
#
|
41
|
+
# @param n_splits [Integer] The number of folds.
|
42
|
+
def initialize(n_splits: 5)
|
43
|
+
check_params_numeric(n_splits: n_splits)
|
44
|
+
@n_splits = n_splits
|
45
|
+
end
|
46
|
+
|
47
|
+
# Generate data indices for grouped K-fold cross validation.
|
48
|
+
#
|
49
|
+
# @overload split(x, y, groups) -> Array
|
50
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
51
|
+
# The dataset to be used to generate data indices for grouped K-fold cross validation.
|
52
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
53
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
54
|
+
# @param groups [Numo::Int32] (shape: [n_samples])
|
55
|
+
# The group labels to be used to generate data indices for grouped K-fold cross validation.
|
56
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
57
|
+
def split(x, _y, groups)
|
58
|
+
x = check_convert_sample_array(x)
|
59
|
+
groups = check_convert_label_array(groups)
|
60
|
+
check_sample_label_size(x, groups)
|
61
|
+
|
62
|
+
encoder = Rumale::Preprocessing::LabelEncoder.new
|
63
|
+
groups = encoder.fit_transform(groups)
|
64
|
+
n_groups = encoder.classes.size
|
65
|
+
|
66
|
+
raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits
|
67
|
+
|
68
|
+
n_samples_per_group = groups.bincount
|
69
|
+
group_ids = n_samples_per_group.sort_index.reverse
|
70
|
+
n_samples_per_group = n_samples_per_group[group_ids]
|
71
|
+
|
72
|
+
n_samples_per_fold = Numo::Int32.zeros(@n_splits)
|
73
|
+
group_to_fold = Numo::Int32.zeros(n_groups)
|
74
|
+
|
75
|
+
n_samples_per_group.each_with_index do |weight, id|
|
76
|
+
min_sample_fold_id = n_samples_per_fold.min_index
|
77
|
+
n_samples_per_fold[min_sample_fold_id] += weight
|
78
|
+
group_to_fold[group_ids[id]] = min_sample_fold_id
|
79
|
+
end
|
80
|
+
|
81
|
+
n_samples = x.shape[0]
|
82
|
+
sample_ids = Array(0...n_samples)
|
83
|
+
fold_ids = group_to_fold[groups]
|
84
|
+
|
85
|
+
Array.new(@n_splits) do |fid|
|
86
|
+
test_ids = fold_ids.eq(fid).where.to_a
|
87
|
+
train_ids = sample_ids - test_ids
|
88
|
+
[train_ids, test_ids]
|
89
|
+
end
|
90
|
+
end
|
91
|
+
end
|
92
|
+
end
|
93
|
+
end
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# GroupShuffleSplit is a class that generates the set of data indices
|
8
|
+
# for random permutation cross-validation by randomly selecting group labels.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
|
12
|
+
# x = Numo::DFloat.new(8, 2).rand
|
13
|
+
# groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
|
14
|
+
# cv.split(x, nil, groups).each do |train_ids, test_ids|
|
15
|
+
# puts '---'
|
16
|
+
# pp train_ids
|
17
|
+
# pp test_ids
|
18
|
+
# end
|
19
|
+
#
|
20
|
+
# # ---
|
21
|
+
# # [0, 1, 2, 5, 6, 7]
|
22
|
+
# # [3, 4]
|
23
|
+
# # ---
|
24
|
+
# # [3, 4, 5, 6, 7]
|
25
|
+
# # [0, 1, 2]
|
26
|
+
#
|
27
|
+
class GroupShuffleSplit
|
28
|
+
include Base::Splitter
|
29
|
+
|
30
|
+
# Return the number of folds.
|
31
|
+
# @return [Integer]
|
32
|
+
attr_reader :n_splits
|
33
|
+
|
34
|
+
# Return the random generator for shuffling the dataset.
|
35
|
+
# @return [Random]
|
36
|
+
attr_reader :rng
|
37
|
+
|
38
|
+
# Create a new data splitter for random permutation cross validation with given group labels.
|
39
|
+
#
|
40
|
+
# @param n_splits [Integer] The number of folds.
|
41
|
+
# @param test_size [Float] The ratio of number of groups for test data.
|
42
|
+
# @param train_size [Float/Nil] The ratio of number of groups for train data.
|
43
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
44
|
+
def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
|
45
|
+
check_params_numeric(n_splits: n_splits, test_size: test_size)
|
46
|
+
check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
|
47
|
+
check_params_positive(n_splits: n_splits)
|
48
|
+
check_params_positive(test_size: test_size)
|
49
|
+
check_params_positive(train_size: train_size) unless train_size.nil?
|
50
|
+
@n_splits = n_splits
|
51
|
+
@test_size = test_size
|
52
|
+
@train_size = train_size
|
53
|
+
@random_seed = random_seed
|
54
|
+
@random_seed ||= srand
|
55
|
+
@rng = Random.new(@random_seed)
|
56
|
+
end
|
57
|
+
|
58
|
+
# Generate train and test data indices by randomly selecting group labels.
|
59
|
+
#
|
60
|
+
# @overload split(x, y, groups) -> Array
|
61
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
62
|
+
# The dataset to be used to generate data indices for random permutation cross validation.
|
63
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
64
|
+
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
65
|
+
# @param groups [Numo::Int32] (shape: [n_samples])
|
66
|
+
# The group labels to be used to generate data indices for random permutation cross validation.
|
67
|
+
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
68
|
+
def split(x, _y, groups)
|
69
|
+
x = check_convert_sample_array(x)
|
70
|
+
groups = check_convert_label_array(groups)
|
71
|
+
check_sample_label_size(x, groups)
|
72
|
+
|
73
|
+
classes = groups.to_a.uniq.sort
|
74
|
+
n_groups = classes.size
|
75
|
+
n_test_groups = (@test_size * n_groups).ceil.to_i
|
76
|
+
n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i
|
77
|
+
|
78
|
+
unless n_test_groups.between?(1, n_groups)
|
79
|
+
raise RangeError,
|
80
|
+
'The number of groups in test split must be not less than 1 and not more than the number of groups.'
|
81
|
+
end
|
82
|
+
unless n_train_groups.between?(1, n_groups)
|
83
|
+
raise RangeError,
|
84
|
+
'The number of groups in train split must be not less than 1 and not more than the number of groups.'
|
85
|
+
end
|
86
|
+
if (n_test_groups + n_train_groups) > n_groups
|
87
|
+
raise RangeError,
|
88
|
+
'The total number of groups in test split and train split must be not more than the number of groups.'
|
89
|
+
end
|
90
|
+
|
91
|
+
sub_rng = @rng.dup
|
92
|
+
|
93
|
+
Array.new(@n_splits) do
|
94
|
+
test_group_ids = classes.sample(n_test_groups, random: sub_rng)
|
95
|
+
train_group_ids = if @train_size.nil?
|
96
|
+
classes - test_group_ids
|
97
|
+
else
|
98
|
+
(classes - test_group_ids).sample(n_train_groups, random: sub_rng)
|
99
|
+
end
|
100
|
+
test_ids = in1d(groups, test_group_ids).where.to_a
|
101
|
+
train_ids = in1d(groups, train_group_ids).where.to_a
|
102
|
+
[train_ids, test_ids]
|
103
|
+
end
|
104
|
+
end
|
105
|
+
|
106
|
+
private
|
107
|
+
|
108
|
+
def in1d(a, b)
|
109
|
+
res = Numo::Bit.zeros(a.shape[0])
|
110
|
+
b.each { |v| res |= a.eq(v) }
|
111
|
+
res
|
112
|
+
end
|
113
|
+
end
|
114
|
+
end
|
115
|
+
end
|
@@ -54,19 +54,19 @@ module Rumale
|
|
54
54
|
x = check_convert_sample_array(x)
|
55
55
|
# Initialize and check some variables.
|
56
56
|
n_samples = x.shape[0]
|
57
|
-
n_test_samples = (@test_size * n_samples).to_i
|
58
|
-
n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).to_i
|
57
|
+
n_test_samples = (@test_size * n_samples).ceil.to_i
|
58
|
+
n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
|
59
59
|
unless @n_splits.between?(1, n_samples)
|
60
60
|
raise ArgumentError,
|
61
61
|
'The value of n_splits must be not less than 1 and not more than the number of samples.'
|
62
62
|
end
|
63
63
|
unless n_test_samples.between?(1, n_samples)
|
64
64
|
raise RangeError,
|
65
|
-
'The number of
|
65
|
+
'The number of samples in test split must be not less than 1 and not more than the number of samples.'
|
66
66
|
end
|
67
67
|
unless n_train_samples.between?(1, n_samples)
|
68
68
|
raise RangeError,
|
69
|
-
'The number of
|
69
|
+
'The number of samples in train split must be not less than 1 and not more than the number of samples.'
|
70
70
|
end
|
71
71
|
if (n_test_samples + n_train_samples) > n_samples
|
72
72
|
raise RangeError,
|
@@ -30,7 +30,7 @@ module Rumale
|
|
30
30
|
# @return [Random]
|
31
31
|
attr_reader :rng
|
32
32
|
|
33
|
-
# Create a new data splitter for K-fold cross validation.
|
33
|
+
# Create a new data splitter for stratified K-fold cross validation.
|
34
34
|
#
|
35
35
|
# @param n_splits [Integer] The number of folds.
|
36
36
|
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
|
@@ -66,15 +66,15 @@ module Rumale
|
|
66
66
|
raise ArgumentError,
|
67
67
|
'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
|
68
68
|
end
|
69
|
-
unless enough_data_size_each_class?(y, @test_size)
|
69
|
+
unless enough_data_size_each_class?(y, @test_size, 'test')
|
70
70
|
raise RangeError,
|
71
|
-
'The number of
|
71
|
+
'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
|
72
72
|
end
|
73
|
-
unless enough_data_size_each_class?(y, train_sz)
|
73
|
+
unless enough_data_size_each_class?(y, train_sz, 'train')
|
74
74
|
raise RangeError,
|
75
|
-
'The number of
|
75
|
+
'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
|
76
76
|
end
|
77
|
-
unless enough_data_size_each_class?(y, train_sz + @test_size)
|
77
|
+
unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
|
78
78
|
raise RangeError,
|
79
79
|
'The total number of samples in test split and train split must be not more than the number of samples in each class.'
|
80
80
|
end
|
@@ -85,12 +85,12 @@ module Rumale
|
|
85
85
|
test_ids = []
|
86
86
|
sample_ids_each_class.each do |sample_ids|
|
87
87
|
n_samples = sample_ids.size
|
88
|
-
n_test_samples = (@test_size * n_samples).to_i
|
89
|
-
n_train_samples = (train_sz * n_samples).to_i
|
88
|
+
n_test_samples = (@test_size * n_samples).ceil.to_i
|
90
89
|
test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
|
91
90
|
train_ids += if @train_size.nil?
|
92
91
|
sample_ids - test_ids
|
93
92
|
else
|
93
|
+
n_train_samples = (train_sz * n_samples).floor.to_i
|
94
94
|
(sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
|
95
95
|
end
|
96
96
|
end
|
@@ -104,9 +104,13 @@ module Rumale
|
|
104
104
|
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(1, n_samples) }
|
105
105
|
end
|
106
106
|
|
107
|
-
def enough_data_size_each_class?(y, data_size)
|
107
|
+
def enough_data_size_each_class?(y, data_size, data_type)
|
108
108
|
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? do |n_samples|
|
109
|
-
|
109
|
+
if data_type == 'test'
|
110
|
+
(data_size * n_samples).ceil.to_i.between?(1, n_samples)
|
111
|
+
else
|
112
|
+
(data_size * n_samples).floor.to_i.between?(1, n_samples)
|
113
|
+
end
|
110
114
|
end
|
111
115
|
end
|
112
116
|
end
|
data/lib/rumale/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.20.0
|
4
|
+
version: 0.20.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-
|
11
|
+
date: 2020-08-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -135,6 +135,8 @@ files:
|
|
135
135
|
- lib/rumale/model_selection/cross_validation.rb
|
136
136
|
- lib/rumale/model_selection/function.rb
|
137
137
|
- lib/rumale/model_selection/grid_search_cv.rb
|
138
|
+
- lib/rumale/model_selection/group_k_fold.rb
|
139
|
+
- lib/rumale/model_selection/group_shuffle_split.rb
|
138
140
|
- lib/rumale/model_selection/k_fold.rb
|
139
141
|
- lib/rumale/model_selection/shuffle_split.rb
|
140
142
|
- lib/rumale/model_selection/stratified_k_fold.rb
|