svmkit 0.7.3 → 0.8.1

This diff compares the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/lib/svmkit/model_selection/grid_search_cv.rb
@@ -1,247 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/evaluator'
- require 'svmkit/base/splitter'
- require 'svmkit/pipeline/pipeline'
-
- module SVMKit
-   module ModelSelection
-     # GridSearchCV is a class that performs hyperparameter optimization with grid search method.
-     #
-     # @example
-     #   rfc = SVMKit::Ensemble::RandomForestClassifier.new(random_seed: 1)
-     #   pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
-     #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
-     #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
-     #   gs.fit(samples, labels)
-     #   p gs.cv_results
-     #   p gs.best_params
-     #
-     # @example
-     #   rbf = SVMKit::KernelApproximation::RBF.new(random_seed: 1)
-     #   svc = SVMKit::LinearModel::SVC.new(random_seed: 1)
-     #   pipe = SVMKit::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
-     #   pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
-     #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
-     #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
-     #   gs.fit(samples, labels)
-     #   p gs.cv_results
-     #   p gs.best_params
-     #
-     class GridSearchCV
-       include Base::BaseEstimator
-       include Validation
-
-       # Return the result of cross validation for each parameter.
-       # @return [Hash]
-       attr_reader :cv_results
-
-       # Return the score of the estimator learned with the best parameter.
-       # @return [Float]
-       attr_reader :best_score
-
-       # Return the best parameter set.
-       # @return [Hash]
-       attr_reader :best_params
-
-       # Return the index of the best parameter.
-       # @return [Integer]
-       attr_reader :best_index
-
-       # Return the estimator learned with the best parameter.
-       # @return [Estimator]
-       attr_reader :best_estimator
-
-       # Create a new grid search method.
-       #
-       # @param estimator [Classifier/Regresor] The estimator to be searched for optimal parameters with grid search method.
-       # @param param_grid [Array<Hash>] The parameter sets is represented with array of hash that
-       #   consists of parameter names as keys and array of parameter values as values.
-       # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
-       # @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
-       #   If nil is given, the score method of estimator is used to evaluation.
-       # @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
-       #   evaluation score is larger.
-       def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
-         check_params_type(SVMKit::Base::BaseEstimator, estimator: estimator)
-         check_params_type(SVMKit::Base::Splitter, splitter: splitter)
-         check_params_type_or_nil(SVMKit::Base::Evaluator, evaluator: evaluator)
-         check_params_boolean(greater_is_better: greater_is_better)
-         @params = {}
-         @params[:param_grid] = valid_param_grid(param_grid)
-         @params[:estimator] = Marshal.load(Marshal.dump(estimator))
-         @params[:splitter] = Marshal.load(Marshal.dump(splitter))
-         @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
-         @params[:greater_is_better] = greater_is_better
-         @cv_results = nil
-         @best_score = nil
-         @best_params = nil
-         @best_index = nil
-         @best_estimator = nil
-       end
-
-       # Fit the model with given training data and all sets of parameters.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
-       # @return [GridSearchCV] The learned estimator with grid search.
-       def fit(x, y)
-         check_sample_array(x)
-
-         init_attrs
-
-         param_combinations.each do |prm_set|
-           prm_set.each do |prms|
-             report = perform_cross_validation(x, y, prms)
-             store_cv_result(prms, report)
-           end
-         end
-
-         find_best_params
-
-         @best_estimator = configurated_estimator(@best_params)
-         @best_estimator.fit(x, y)
-         self
-       end
-
-       # Call the decision_function method of learned estimator with the best parameter.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
-       def decision_function(x)
-         check_sample_array(x)
-         @best_estimator.decision_function(x)
-       end
-
-       # Call the predict method of learned estimator with the best parameter.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
-       # @return [Numo::NArray] Predicted results.
-       def predict(x)
-         check_sample_array(x)
-         @best_estimator.predict(x)
-       end
-
-       # Call the predict_log_proba method of learned estimator with the best parameter.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probailities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
-       def predict_log_proba(x)
-         check_sample_array(x)
-         @best_estimator.predict_log_proba(x)
-       end
-
-       # Call the predict_proba method of learned estimator with the best parameter.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
-       def predict_proba(x)
-         check_sample_array(x)
-         @best_estimator.predict_proba(x)
-       end
-
-       # Call the score method of learned estimator with the best parameter.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
-       # @return [Float] The score of estimator.
-       def score(x, y)
-         check_sample_array(x)
-         @best_estimator.score(x, y)
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about GridSearchCV.
-       def marshal_dump
-         { params: @params,
-           cv_results: @cv_results,
-           best_score: @best_score,
-           best_params: @best_params,
-           best_index: @best_index,
-           best_estimator: @best_estimator }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @cv_results = obj[:cv_results]
-         @best_score = obj[:best_score]
-         @best_params = obj[:best_params]
-         @best_index = obj[:best_index]
-         @best_estimator = obj[:best_estimator]
-         nil
-       end
-
-       private
-
-       def valid_param_grid(grid)
-         raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)
-         grid = [grid] if grid.is_a?(Hash)
-         grid.each do |h|
-           raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
-           raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all? { |v| v.is_a?(Array) }
-         end
-         grid
-       end
-
-       def param_combinations
-         @param_combinations ||= @params[:param_grid].map do |prm|
-           x = Hash[prm.sort].map { |k, v| [k].product(v) }
-           x[0].product(*x[1...x.size]).map { |v| Hash[v] }
-         end
-       end
-
-       def perform_cross_validation(x, y, prms)
-         est = configurated_estimator(prms)
-         cv = CrossValidation.new(estimator: est, splitter: @params[:splitter],
-                                  evaluator: @params[:evaluator], return_train_score: true)
-         cv.perform(x, y)
-       end
-
-       def configurated_estimator(prms)
-         estimator = Marshal.load(Marshal.dump(@params[:estimator]))
-         if @params[:estimator].is_a?(SVMKit::Pipeline::Pipeline)
-           prms.each do |k, v|
-             est_name, prm_name = k.to_s.split('__')
-             estimator.steps[est_name.to_sym].params[prm_name.to_sym] = v
-           end
-         else
-           prms.each { |k, v| estimator.params[k] = v }
-         end
-         estimator
-       end
-
-       def init_attrs
-         @cv_results = %i[mean_test_score std_test_score
-                          mean_train_score std_train_score
-                          mean_fit_time std_fit_time params].map { |v| [v, []] }.to_h
-         @best_score = nil
-         @best_params = nil
-         @best_index = nil
-         @best_estimator = nil
-       end
-
-       def store_cv_result(prms, report)
-         test_scores = Numo::DFloat[*report[:test_score]]
-         train_scores = Numo::DFloat[*report[:train_score]]
-         fit_times = Numo::DFloat[*report[:fit_time]]
-         @cv_results[:mean_test_score].push(test_scores.mean)
-         @cv_results[:std_test_score].push(test_scores.stddev)
-         @cv_results[:mean_train_score].push(train_scores.mean)
-         @cv_results[:std_train_score].push(train_scores.stddev)
-         @cv_results[:mean_fit_time].push(fit_times.mean)
-         @cv_results[:std_fit_time].push(fit_times.stddev)
-         @cv_results[:params].push(prms)
-       end
-
-       def find_best_params
-         @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
-         @best_index = @cv_results[:mean_test_score].index(@best_score)
-         @best_params = @cv_results[:params][@best_index]
-       end
-     end
-   end
- end
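
The hunk above removes GridSearchCV in its entirety. Its core mechanic: each Hash in `param_grid` is expanded into the cartesian product of its value arrays (`param_combinations`), every combination is cross-validated, and the estimator is refit with the best one. Below is a quick illustration of that expansion step in plain Ruby, mirroring the removed `param_combinations` logic; the grid `pg` is a made-up example, not taken from the gem.

```ruby
# Sort the keys, pair each key with each candidate value, then take the
# cartesian product of the per-key pair lists, as the removed
# GridSearchCV#param_combinations did.
pg = { n_estimators: [5, 10], max_depth: [3, 5] } # hypothetical grid
pairs = Hash[pg.sort].map { |k, v| [k].product(v) }
combos = pairs[0].product(*pairs[1...pairs.size]).map { |v| Hash[v] }
p combos
# => [{:max_depth=>3, :n_estimators=>5}, {:max_depth=>3, :n_estimators=>10},
#     {:max_depth=>5, :n_estimators=>5}, {:max_depth=>5, :n_estimators=>10}]
```

Each resulting Hash is then scored via `perform_cross_validation`, and `find_best_params` keeps the combination with the maximum (or minimum, when `greater_is_better` is false) mean test score.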
data/lib/svmkit/model_selection/k_fold.rb
@@ -1,77 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/splitter'
-
- module SVMKit
-   # This module consists of the classes for model validation techniques.
-   module ModelSelection
-     # KFold is a class that generates the set of data indices for K-fold cross-validation.
-     #
-     # @example
-     #   kf = SVMKit::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
-     #   kf.split(samples, labels).each do |train_ids, test_ids|
-     #     train_samples = samples[train_ids, true]
-     #     test_samples = samples[test_ids, true]
-     #     ...
-     #   end
-     #
-     class KFold
-       include Base::Splitter
-
-       # Return the flag indicating whether to shuffle the dataset.
-       # @return [Boolean]
-       attr_reader :shuffle
-
-       # Return the random generator for shuffling the dataset.
-       # @return [Random]
-       attr_reader :rng
-
-       # Create a new data splitter for K-fold cross validation.
-       #
-       # @param n_splits [Integer] The number of folds.
-       # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
-       # @param random_seed [Integer] The seed value using to initialize the random generator.
-       def initialize(n_splits: 3, shuffle: false, random_seed: nil)
-         SVMKit::Validation.check_params_integer(n_splits: n_splits)
-         SVMKit::Validation.check_params_boolean(shuffle: shuffle)
-         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-         SVMKit::Validation.check_params_positive(n_splits: n_splits)
-         @n_splits = n_splits
-         @shuffle = shuffle
-         @random_seed = random_seed
-         @random_seed ||= srand
-         @rng = Random.new(@random_seed)
-       end
-
-       # Generate data indices for K-fold cross validation.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The dataset to be used to generate data indices for K-fold cross validation.
-       # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
-       def split(x, _y = nil)
-         SVMKit::Validation.check_sample_array(x)
-         # Initialize and check some variables.
-         n_samples, = x.shape
-         unless @n_splits.between?(2, n_samples)
-           raise ArgumentError,
-                 'The value of n_splits must be not less than 2 and not more than the number of samples.'
-         end
-         # Splits dataset ids to each fold.
-         dataset_ids = [*0...n_samples]
-         dataset_ids.shuffle!(random: @rng) if @shuffle
-         fold_sets = Array.new(@n_splits) do |n|
-           n_fold_samples = n_samples / @n_splits
-           n_fold_samples += 1 if n < n_samples % @n_splits
-           dataset_ids.shift(n_fold_samples)
-         end
-         # Returns array consisting of the training and testing ids for each fold.
-         Array.new(@n_splits) do |n|
-           train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
-           test_ids = fold_sets[n]
-           [train_ids, test_ids]
-         end
-       end
-     end
-   end
- end
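
Also removed: KFold. Its fold sizing rule is integer division plus a one-sample top-up for the first `n_samples % n_splits` folds, consuming a (possibly shuffled) id list destructively. A worked example of that arithmetic in plain Ruby; the sizes are made up:

```ruby
# 10 samples over 3 splits gives folds of 4, 3, 3:
# the first (10 % 3) = 1 fold receives one extra id.
n_samples, n_splits = 10, 3
ids = [*0...n_samples]
folds = Array.new(n_splits) do |n|
  size = n_samples / n_splits           # 3
  size += 1 if n < n_samples % n_splits # only fold 0 here
  ids.shift(size)                       # consume ids, as the removed code does
end
p folds # => [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
# Training ids for fold n are all remaining folds flattened:
p folds.select.with_index { |_, i| i != 0 }.flatten # => [4, 5, 6, 7, 8, 9]
```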
data/lib/svmkit/model_selection/stratified_k_fold.rb
@@ -1,95 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/splitter'
-
- module SVMKit
-   module ModelSelection
-     # StratifiedKFold is a class that generates the set of data indices for K-fold cross-validation.
-     # The proportion of the number of samples in each class will be almost equal for each fold.
-     #
-     # @example
-     #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
-     #   kf.split(samples, labels).each do |train_ids, test_ids|
-     #     train_samples = samples[train_ids, true]
-     #     test_samples = samples[test_ids, true]
-     #     ...
-     #   end
-     #
-     class StratifiedKFold
-       include Base::Splitter
-
-       # Return the flag indicating whether to shuffle the dataset.
-       # @return [Boolean]
-       attr_reader :shuffle
-
-       # Return the random generator for shuffling the dataset.
-       # @return [Random]
-       attr_reader :rng
-
-       # Create a new data splitter for K-fold cross validation.
-       #
-       # @param n_splits [Integer] The number of folds.
-       # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
-       # @param random_seed [Integer] The seed value using to initialize the random generator.
-       def initialize(n_splits: 3, shuffle: false, random_seed: nil)
-         SVMKit::Validation.check_params_integer(n_splits: n_splits)
-         SVMKit::Validation.check_params_boolean(shuffle: shuffle)
-         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-         SVMKit::Validation.check_params_positive(n_splits: n_splits)
-         @n_splits = n_splits
-         @shuffle = shuffle
-         @random_seed = random_seed
-         @random_seed ||= srand
-         @rng = Random.new(@random_seed)
-       end
-
-       # Generate data indices for stratified K-fold cross validation.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The dataset to be used to generate data indices for stratified K-fold cross validation.
-       #   This argument exists to unify the interface between the K-fold methods, it is not used in the method.
-       # @param y [Numo::Int32] (shape: [n_samples])
-       #   The labels to be used to generate data indices for stratified K-fold cross validation.
-       # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
-       def split(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         # Check the number of samples in each class.
-         unless valid_n_splits?(y)
-           raise ArgumentError,
-                 'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
-         end
-         # Splits dataset ids of each class to each fold.
-         fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label) }
-         # Returns array consisting of the training and testing ids for each fold.
-         Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
-       end
-
-       private
-
-       def valid_n_splits?(y)
-         y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(2, n_samples) }
-       end
-
-       def fold_sets(y, label)
-         sample_ids = y.eq(label).where.to_a
-         sample_ids.shuffle!(random: @rng) if @shuffle
-         n_samples = sample_ids.size
-         Array.new(@n_splits) do |n|
-           n_fold_samples = n_samples / @n_splits
-           n_fold_samples += 1 if n < n_samples % @n_splits
-           sample_ids.shift(n_fold_samples)
-         end
-       end
-
-       def train_test_sets(fold_sets_each_class, fold_id)
-         train_test_sets_each_class = fold_sets_each_class.map do |folds|
-           folds.partition.with_index { |_, id| id != fold_id }.map(&:flatten)
-         end
-         train_test_sets_each_class.transpose.map(&:flatten)
-       end
-     end
-   end
- end
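
StratifiedKFold applies the same sizing rule, but separately within each class, so every fold keeps roughly the original class proportions. A minimal usage sketch, assuming svmkit 0.7.x and its numo-narray dependency are installed; the data here is made up:

```ruby
require 'svmkit'

# Four samples per class; with n_splits: 2 every test fold keeps a 50/50 mix.
x = Numo::DFloat.new(8, 2).rand
y = Numo::Int32[0, 0, 0, 0, 1, 1, 1, 1]

skf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 2, shuffle: true, random_seed: 1)
skf.split(x, y).each do |train_ids, test_ids|
  p y[test_ids].to_a.sort # => [0, 0, 1, 1] in both folds
end
```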
data/lib/svmkit/multiclass/one_vs_rest_classifier.rb
@@ -1,101 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator.rb'
- require 'svmkit/base/classifier.rb'
-
- module SVMKit
-   # This module consists of the classes that implement multi-class classification strategy.
-   module Multiclass
-     # @note
-     #   All classifier in SVMKit support multi-class classifiction since version 0.2.7.
-     #   There is no need to explicitly use this class for multiclass classifiction.
-     #
-     # OneVsRestClassifier is a class that implements One-vs-Rest (OvR) strategy for multi-class classification.
-     #
-     # @example
-     #   base_estimator = SVMKit::LinearModel::LogisticRegression.new
-     #   estimator = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     class OneVsRestClassifier
-       include Base::BaseEstimator
-       include Base::Classifier
-
-       # Return the set of estimators.
-       # @return [Array<Classifier>]
-       attr_reader :estimators
-
-       # Return the class labels.
-       # @return [Numo::Int32] (shape: [n_classes])
-       attr_reader :classes
-
-       # Create a new multi-class classifier with the one-vs-rest startegy.
-       #
-       # @param estimator [Classifier] The (binary) classifier for construction a multi-class classifier.
-       def initialize(estimator: nil)
-         SVMKit::Validation.check_params_type(SVMKit::Base::BaseEstimator, estimator: estimator)
-         @params = {}
-         @params[:estimator] = estimator
-         @estimators = nil
-         @classes = nil
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
-       # @return [OneVsRestClassifier] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         y_arr = y.to_a
-         @classes = Numo::Int32.asarray(y_arr.uniq.sort)
-         @estimators = @classes.to_a.map do |label|
-           bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
-           @params[:estimator].dup.fit(x, bin_y)
-         end
-         self
-       end
-
-       # Calculate confidence scores for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
-       def decision_function(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_classes = @classes.size
-         Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
-       end
-
-       # Predict class labels for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples, = x.shape
-         decision_values = decision_function(x)
-         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about OneVsRestClassifier.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           estimators: @estimators.map { |e| Marshal.dump(e) } }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @estimators = obj[:estimators].map { |e| Marshal.load(e) }
-         nil
-       end
-     end
-   end
- end
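
As the @note in the hunk says, every classifier in svmkit had been natively multi-class since 0.2.7, which is why 0.8.1 can drop this wrapper outright. For reference, a minimal sketch of what it did, based on the class's own @example; it assumes svmkit 0.7.x is installed and the toy data is made up:

```ruby
require 'svmkit'

# Three linearly separated clusters of 20 samples each.
n = 60
y = Numo::Int32.asarray(Array.new(n) { |i| i / 20 })  # labels 0, 1, 2
x = Numo::DFloat.new(n, 2).rand + y.cast_to(Numo::DFloat).reshape(n, 1) * 4.0

base = SVMKit::LinearModel::LogisticRegression.new(random_seed: 1)
ovr  = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base)
ovr.fit(x, y) # fits one binary (1 vs. -1) copy of `base` per class

p ovr.decision_function(x).shape     # => [60, 3], one score column per class
p ovr.predict(x).eq(y).where.size    # correct predictions; predict is the per-row argmax
```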