svmkit 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +0 -9
- data/.rspec +1 -0
- data/.travis.yml +4 -12
- data/LICENSE.txt +1 -1
- data/README.md +11 -13
- data/lib/svmkit.rb +3 -66
- data/svmkit.gemspec +12 -7
- metadata +16 -81
- data/.coveralls.yml +0 -1
- data/.rubocop.yml +0 -47
- data/.rubocop_todo.yml +0 -58
- data/HISTORY.md +0 -168
- data/lib/svmkit/base/base_estimator.rb +0 -13
- data/lib/svmkit/base/classifier.rb +0 -34
- data/lib/svmkit/base/cluster_analyzer.rb +0 -29
- data/lib/svmkit/base/evaluator.rb +0 -13
- data/lib/svmkit/base/regressor.rb +0 -34
- data/lib/svmkit/base/splitter.rb +0 -17
- data/lib/svmkit/base/transformer.rb +0 -18
- data/lib/svmkit/clustering/dbscan.rb +0 -127
- data/lib/svmkit/clustering/k_means.rb +0 -140
- data/lib/svmkit/dataset.rb +0 -109
- data/lib/svmkit/decomposition/nmf.rb +0 -147
- data/lib/svmkit/decomposition/pca.rb +0 -150
- data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
- data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
- data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
- data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
- data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
- data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
- data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
- data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
- data/lib/svmkit/evaluation_measure/precision.rb +0 -51
- data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
- data/lib/svmkit/evaluation_measure/purity.rb +0 -41
- data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
- data/lib/svmkit/evaluation_measure/recall.rb +0 -51
- data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
- data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
- data/lib/svmkit/linear_model/lasso.rb +0 -138
- data/lib/svmkit/linear_model/linear_regression.rb +0 -112
- data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
- data/lib/svmkit/linear_model/ridge.rb +0 -112
- data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
- data/lib/svmkit/linear_model/svc.rb +0 -184
- data/lib/svmkit/linear_model/svr.rb +0 -123
- data/lib/svmkit/model_selection/cross_validation.rb +0 -121
- data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
- data/lib/svmkit/model_selection/k_fold.rb +0 -77
- data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
- data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
- data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
- data/lib/svmkit/optimizer/nadam.rb +0 -90
- data/lib/svmkit/optimizer/rmsprop.rb +0 -69
- data/lib/svmkit/optimizer/sgd.rb +0 -65
- data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
- data/lib/svmkit/pairwise_metric.rb +0 -91
- data/lib/svmkit/pipeline/pipeline.rb +0 -197
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
- data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
- data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
- data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
- data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
- data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
- data/lib/svmkit/probabilistic_output.rb +0 -112
- data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
- data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
- data/lib/svmkit/tree/node.rb +0 -70
- data/lib/svmkit/utils.rb +0 -22
- data/lib/svmkit/validation.rb +0 -79
- data/lib/svmkit/values.rb +0 -13
- data/lib/svmkit/version.rb +0 -7
@@ -1,247 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/base_estimator'
|
5
|
-
require 'svmkit/base/evaluator'
|
6
|
-
require 'svmkit/base/splitter'
|
7
|
-
require 'svmkit/pipeline/pipeline'
|
8
|
-
|
9
|
-
module SVMKit
  module ModelSelection
    # GridSearchCV is a class that performs hyperparameter optimization with grid search method.
    #
    # @example
    #   rfc = SVMKit::Ensemble::RandomForestClassifier.new(random_seed: 1)
    #   pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    # @example
    #   rbf = SVMKit::KernelApproximation::RBF.new(random_seed: 1)
    #   svc = SVMKit::LinearModel::SVC.new(random_seed: 1)
    #   pipe = SVMKit::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
    #   pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    class GridSearchCV
      include Base::BaseEstimator
      include Validation

      # Return the result of cross validation for each parameter.
      # @return [Hash]
      attr_reader :cv_results

      # Return the score of the estimator learned with the best parameter.
      # @return [Float]
      attr_reader :best_score

      # Return the best parameter set.
      # @return [Hash]
      attr_reader :best_params

      # Return the index of the best parameter.
      # @return [Integer]
      attr_reader :best_index

      # Return the estimator learned with the best parameter.
      # @return [Estimator]
      attr_reader :best_estimator

      # Create a new grid search method.
      #
      # @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with grid search method.
      #   When a Pipeline is given, parameter names take the form "step_name__param_name".
      # @param param_grid [Hash/Array<Hash>] The parameter sets are represented with a hash (or an array of hashes) that
      #   consists of parameter names as keys and arrays of parameter values as values.
      #   An empty hash stands for a single parameter set that keeps the estimator's own parameters.
      # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
      # @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
      #   If nil is given, the score method of estimator is used to evaluation.
      # @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
      #   evaluation score is larger.
      # @raise [TypeError] If param_grid is not a Hash / Array of Hashes whose values are Arrays.
      def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
        check_params_type(SVMKit::Base::BaseEstimator, estimator: estimator)
        check_params_type(SVMKit::Base::Splitter, splitter: splitter)
        check_params_type_or_nil(SVMKit::Base::Evaluator, evaluator: evaluator)
        check_params_boolean(greater_is_better: greater_is_better)
        @params = {}
        @params[:param_grid] = valid_param_grid(param_grid)
        # Deep copies so that later changes to the caller's objects do not leak into the search.
        @params[:estimator] = Marshal.load(Marshal.dump(estimator))
        @params[:splitter] = Marshal.load(Marshal.dump(splitter))
        @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
        @params[:greater_is_better] = greater_is_better
        @cv_results = nil
        @best_score = nil
        @best_params = nil
        @best_index = nil
        @best_estimator = nil
      end

      # Fit the model with given training data and all sets of parameters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
      # @return [GridSearchCV] The learned estimator with grid search.
      def fit(x, y)
        check_sample_array(x)

        init_attrs

        param_combinations.each do |prm_set|
          prm_set.each do |prms|
            report = perform_cross_validation(x, y, prms)
            store_cv_result(prms, report)
          end
        end

        find_best_params

        # Refit on the whole training set with the winning parameters.
        @best_estimator = configurated_estimator(@best_params)
        @best_estimator.fit(x, y)
        self
      end

      # Call the decision_function method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
      def decision_function(x)
        check_sample_array(x)
        @best_estimator.decision_function(x)
      end

      # Call the predict method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
      # @return [Numo::NArray] Predicted results.
      def predict(x)
        check_sample_array(x)
        @best_estimator.predict(x)
      end

      # Call the predict_log_proba method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
      def predict_log_proba(x)
        check_sample_array(x)
        @best_estimator.predict_log_proba(x)
      end

      # Call the predict_proba method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        check_sample_array(x)
        @best_estimator.predict_proba(x)
      end

      # Call the score method of learned estimator with the best parameter.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
      # @return [Float] The score of estimator.
      def score(x, y)
        check_sample_array(x)
        @best_estimator.score(x, y)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about GridSearchCV.
      def marshal_dump
        { params: @params,
          cv_results: @cv_results,
          best_score: @best_score,
          best_params: @best_params,
          best_index: @best_index,
          best_estimator: @best_estimator }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cv_results = obj[:cv_results]
        @best_score = obj[:best_score]
        @best_params = obj[:best_params]
        @best_index = obj[:best_index]
        @best_estimator = obj[:best_estimator]
        nil
      end

      private

      # Normalize param_grid to an Array of Hashes, raising TypeError on malformed input.
      def valid_param_grid(grid)
        raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)
        grid = [grid] if grid.is_a?(Hash)
        grid.each do |h|
          raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
          raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all? { |v| v.is_a?(Array) }
        end
        grid
      end

      # For each grid hash, expand the cartesian product of its value arrays into parameter Hashes.
      # Keys are ordered by name, and the first key varies slowest.
      def param_combinations
        @param_combinations ||= @params[:param_grid].map do |grid|
          keys = grid.keys.sort_by(&:to_s)
          # An empty grid is one parameter set that keeps the estimator's own parameters
          # (previously this crashed with NoMethodError on nil).
          next [{}] if keys.empty?

          values = keys.map { |k| grid[k] }
          values[0].product(*values[1..-1]).map { |combo| keys.zip(combo).to_h }
        end
      end

      # Run cross validation for one parameter set and return the CrossValidation report.
      def perform_cross_validation(x, y, prms)
        est = configurated_estimator(prms)
        cv = CrossValidation.new(estimator: est, splitter: @params[:splitter],
                                 evaluator: @params[:evaluator], return_train_score: true)
        cv.perform(x, y)
      end

      # Return a deep copy of the base estimator with the given parameters applied.
      # For a Pipeline, each key is "step_name__param_name"; only the first "__" separates the step name.
      def configurated_estimator(prms)
        estimator = Marshal.load(Marshal.dump(@params[:estimator]))
        if @params[:estimator].is_a?(SVMKit::Pipeline::Pipeline)
          prms.each do |k, v|
            est_name, prm_name = k.to_s.split('__', 2)
            step = prm_name.nil? ? nil : estimator.steps[est_name.to_sym]
            if step.nil?
              raise ArgumentError,
                    "Invalid parameter name for pipeline: #{k} (expected an existing step as 'step_name__param_name')"
            end
            step.params[prm_name.to_sym] = v
          end
        else
          prms.each { |k, v| estimator.params[k] = v }
        end
        estimator
      end

      # Reset the search results before a new fit.
      def init_attrs
        @cv_results = %i[mean_test_score std_test_score
                         mean_train_score std_train_score
                         mean_fit_time std_fit_time params].map { |v| [v, []] }.to_h
        @best_score = nil
        @best_params = nil
        @best_index = nil
        @best_estimator = nil
      end

      # Append the mean/stddev of the test scores, train scores and fit times of one report.
      def store_cv_result(prms, report)
        test_scores = Numo::DFloat[*report[:test_score]]
        train_scores = Numo::DFloat[*report[:train_score]]
        fit_times = Numo::DFloat[*report[:fit_time]]
        @cv_results[:mean_test_score].push(test_scores.mean)
        @cv_results[:std_test_score].push(test_scores.stddev)
        @cv_results[:mean_train_score].push(train_scores.mean)
        @cv_results[:std_train_score].push(train_scores.stddev)
        @cv_results[:mean_fit_time].push(fit_times.mean)
        @cv_results[:std_fit_time].push(fit_times.stddev)
        @cv_results[:params].push(prms)
      end

      # Pick the best mean test score (max or min per greater_is_better); ties go to the first index.
      def find_best_params
        @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
        @best_index = @cv_results[:mean_test_score].index(@best_score)
        @best_params = @cv_results[:params][@best_index]
      end
    end
  end
end
|
@@ -1,77 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/splitter'
|
5
|
-
|
6
|
-
module SVMKit
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    # KFold is a class that generates the set of data indices for K-fold cross-validation.
    #
    # @example
    #   kf = SVMKit::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class KFold
      include Base::Splitter

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   If nil is given, a fresh random seed is drawn.
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        SVMKit::Validation.check_params_integer(n_splits: n_splits)
        SVMKit::Validation.check_params_boolean(shuffle: shuffle)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
        SVMKit::Validation.check_params_positive(n_splits: n_splits)
        @n_splits = n_splits
        @shuffle = shuffle
        # Random.new_seed draws a seed without touching the global Kernel RNG;
        # calling srand here would reseed the program-wide random state as a side effect.
        @random_seed = random_seed || Random.new_seed
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for K-fold cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for K-fold cross validation.
      # @param _y [Object] Ignored; accepted so that KFold and StratifiedKFold share the same interface.
      # @raise [ArgumentError] If n_splits is less than 2 or more than the number of samples.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, _y = nil)
        SVMKit::Validation.check_sample_array(x)
        # Initialize and check some variables.
        n_samples, = x.shape
        unless @n_splits.between?(2, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples.'
        end
        # Splits dataset ids to each fold; the first (n_samples % n_splits) folds get one extra sample.
        dataset_ids = [*0...n_samples]
        dataset_ids.shuffle!(random: @rng) if @shuffle
        base_size, n_larger = n_samples.divmod(@n_splits)
        fold_sets = Array.new(@n_splits) do |n|
          dataset_ids.shift(n < n_larger ? base_size + 1 : base_size)
        end
        # Returns array consisting of the training and testing ids for each fold.
        Array.new(@n_splits) do |n|
          train_ids = fold_sets.reject.with_index { |_, id| id == n }.flatten
          [train_ids, fold_sets[n]]
        end
      end
    end
  end
end
|
@@ -1,95 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/splitter'
|
5
|
-
|
6
|
-
module SVMKit
  module ModelSelection
    # StratifiedKFold is a class that generates the set of data indices for K-fold cross-validation.
    # The proportion of the number of samples in each class will be almost equal for each fold.
    #
    # @example
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class StratifiedKFold
      include Base::Splitter

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for stratified K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   If nil is given, a fresh random seed is drawn.
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        SVMKit::Validation.check_params_integer(n_splits: n_splits)
        SVMKit::Validation.check_params_boolean(shuffle: shuffle)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
        SVMKit::Validation.check_params_positive(n_splits: n_splits)
        @n_splits = n_splits
        @shuffle = shuffle
        # Random.new_seed draws a seed without touching the global Kernel RNG;
        # calling srand here would reseed the program-wide random state as a side effect.
        @random_seed = random_seed || Random.new_seed
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for stratified K-fold cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for stratified K-fold cross validation.
      #   This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   The labels to be used to generate data indices for stratified K-fold cross validation.
      # @raise [ArgumentError] If n_splits is less than 2 or more than the number of samples of some class.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
        SVMKit::Validation.check_sample_label_size(x, y)
        # Class labels in order of first appearance; computed once and reused below.
        labels = y.to_a.uniq
        # Check the number of samples in each class.
        unless valid_n_splits?(y, labels)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
        end
        # Splits dataset ids of each class to each fold.
        fold_sets_each_class = labels.map { |label| fold_sets(y, label) }
        # Returns array consisting of the training and testing ids for each fold.
        Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
      end

      private

      # True when every class has at least n_splits samples (and n_splits >= 2).
      def valid_n_splits?(y, labels)
        labels.all? { |label| @n_splits.between?(2, y.eq(label).where.size) }
      end

      # Divide the sample ids of one class into n_splits folds;
      # the first (n_samples % n_splits) folds get one extra sample.
      def fold_sets(y, label)
        sample_ids = y.eq(label).where.to_a
        sample_ids.shuffle!(random: @rng) if @shuffle
        base_size, n_larger = sample_ids.size.divmod(@n_splits)
        Array.new(@n_splits) do |n|
          sample_ids.shift(n < n_larger ? base_size + 1 : base_size)
        end
      end

      # Merge the per-class folds into [train_ids, test_ids] for the given fold.
      def train_test_sets(fold_sets_each_class, fold_id)
        train_test_sets_each_class = fold_sets_each_class.map do |folds|
          folds.partition.with_index { |_, id| id != fold_id }.map(&:flatten)
        end
        train_test_sets_each_class.transpose.map(&:flatten)
      end
    end
  end
end
|
@@ -1,101 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/base_estimator.rb'
|
5
|
-
require 'svmkit/base/classifier.rb'
|
6
|
-
|
7
|
-
module SVMKit
  # This module consists of the classes that implement multi-class classification strategy.
  module Multiclass
    # @note
    #   All classifiers in SVMKit support multi-class classification since version 0.2.7.
    #   There is no need to explicitly use this class for multi-class classification.
    #
    # OneVsRestClassifier is a class that implements One-vs-Rest (OvR) strategy for multi-class classification.
    #
    # @example
    #   base_estimator = SVMKit::LinearModel::LogisticRegression.new
    #   estimator = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    class OneVsRestClassifier
      include Base::BaseEstimator
      include Base::Classifier

      # Return the set of estimators.
      # @return [Array<Classifier>]
      attr_reader :estimators

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new multi-class classifier with the one-vs-rest strategy.
      #
      # @param estimator [Classifier] The (binary) classifier for construction a multi-class classifier.
      def initialize(estimator: nil)
        SVMKit::Validation.check_params_type(SVMKit::Base::BaseEstimator, estimator: estimator)
        @params = {}
        @params[:estimator] = estimator
        @estimators = nil
        @classes = nil
      end

      # Fit the model with given training data.
      # One binary estimator is trained per class, with labels +1 for the class and -1 for the rest.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [OneVsRestClassifier] The learned classifier itself.
      def fit(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
        SVMKit::Validation.check_sample_label_size(x, y)
        y_arr = y.to_a
        @classes = Numo::Int32.asarray(y_arr.uniq.sort)
        @estimators = @classes.to_a.map do |label|
          bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
          # Deep-copy the base estimator: Object#dup is shallow, so the copies would otherwise
          # share mutable instance state (such as the params hash) with each other and with
          # the caller's estimator object.
          Marshal.load(Marshal.dump(@params[:estimator])).fit(x, bin_y)
        end
        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
      def decision_function(x)
        SVMKit::Validation.check_sample_array(x)
        n_classes = @classes.size
        Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        SVMKit::Validation.check_sample_array(x)
        n_samples, = x.shape
        decision_values = decision_function(x)
        # Each sample gets the class whose binary estimator returns the highest confidence.
        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about OneVsRestClassifier.
      def marshal_dump
        # Safe navigation keeps an unfitted model (estimators == nil) dumpable.
        { params: @params,
          classes: @classes,
          estimators: @estimators&.map { |e| Marshal.dump(e) } }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @classes = obj[:classes]
        @estimators = obj[:estimators]&.map { |e| Marshal.load(e) }
        nil
      end
    end
  end
end
|