svmkit 0.7.3 → 0.8.1

Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/lib/svmkit/naive_bayes/naive_bayes.rb
@@ -1,316 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/classifier'
-
- module SVMKit
-   # This module consists of the classes that implement naive bayes models.
-   module NaiveBayes
-     # BaseNaiveBayes is a class that has methods for common processes of naive bayes classifier.
-     class BaseNaiveBayes
-       include Base::BaseEstimator
-       include Base::Classifier
-
-       # Predict class labels for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples = x.shape.first
-         decision_values = decision_function(x)
-         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
-       end
-
-       # Predict log-probability for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
-       def predict_log_proba(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples, = x.shape
-         log_likelihoods = decision_function(x)
-         log_likelihoods - Numo::NMath.log(Numo::NMath.exp(log_likelihoods).sum(1)).reshape(n_samples, 1)
-       end
-
-       # Predict probability for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
-       def predict_proba(x)
-         SVMKit::Validation.check_sample_array(x)
-         Numo::NMath.exp(predict_log_proba(x)).abs
-       end
-     end
-
-     # GaussianNB is a class that implements Gaussian Naive Bayes classifier.
-     #
-     # @example
-     #   estimator = SVMKit::NaiveBayes::GaussianNB.new
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     class GaussianNB < BaseNaiveBayes
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Return the prior probabilities of the classes.
-       # @return [Numo::DFloat] (shape: [n_classes])
-       attr_reader :class_priors
-
-       # Return the mean vectors of the classes.
-       # @return [Numo::DFloat] (shape: [n_classes, n_features])
-       attr_reader :means
-
-       # Return the variance vectors of the classes.
-       # @return [Numo::DFloat] (shape: [n_classes, n_features])
-       attr_reader :variances
-
-       # Create a new classifier with Gaussian Naive Bayes.
-       def initialize
-         @params = {}
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
-       #   to be used for fitting the model.
-       # @return [GaussianNB] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         n_samples, = x.shape
-         @classes = Numo::Int32[*y.to_a.uniq.sort]
-         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
-         @means = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].mean(0) }]
-         @variances = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].var(0) }]
-         self
-       end
-
-       # Calculate confidence scores for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
-       def decision_function(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_classes = @classes.size
-         log_likelihoods = Array.new(n_classes) do |l|
-           Math.log(@class_priors[l]) - 0.5 * (
-             Numo::NMath.log(2.0 * Math::PI * @variances[l, true]) +
-             ((x - @means[l, true])**2 / @variances[l, true])).sum(1)
-         end
-         Numo::DFloat[*log_likelihoods].transpose
-       end
-
-       # Dump marshal data.
-       #
-       # @return [Hash] The marshal data about GaussianNB.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           class_priors: @class_priors,
-           means: @means,
-           variances: @variances }
-       end
-
-       # Load marshal data.
-       #
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @class_priors = obj[:class_priors]
-         @means = obj[:means]
-         @variances = obj[:variances]
-         nil
-       end
-     end
-
-     # MultinomialNB is a class that implements Multinomial Naive Bayes classifier.
-     #
-     # @example
-     #   estimator = SVMKit::NaiveBayes::MultinomialNB.new(smoothing_param: 1.0)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     #
-     # *Reference*
-     # - C. D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press, 2008.
-     class MultinomialNB < BaseNaiveBayes
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Return the prior probabilities of the classes.
-       # @return [Numo::DFloat] (shape: [n_classes])
-       attr_reader :class_priors
-
-       # Return the conditional probabilities for features of each class.
-       # @return [Numo::DFloat] (shape: [n_classes, n_features])
-       attr_reader :feature_probs
-
-       # Create a new classifier with Multinomial Naive Bayes.
-       #
-       # @param smoothing_param [Float] The Laplace smoothing parameter.
-       def initialize(smoothing_param: 1.0)
-         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param)
-         SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
-         @params = {}
-         @params[:smoothing_param] = smoothing_param
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
-       #   to be used for fitting the model.
-       # @return [MultinomialNB] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         n_samples, = x.shape
-         @classes = Numo::Int32[*y.to_a.uniq.sort]
-         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
-         count_features = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].sum(0) }]
-         count_features += @params[:smoothing_param]
-         n_classes = @classes.size
-         @feature_probs = count_features / count_features.sum(1).reshape(n_classes, 1)
-         self
-       end
-
-       # Calculate confidence scores for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
-       def decision_function(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_classes = @classes.size
-         bin_x = x.gt(0)
-         log_likelihoods = Array.new(n_classes) do |l|
-           Math.log(@class_priors[l]) + (Numo::DFloat[*bin_x] * Numo::NMath.log(@feature_probs[l, true])).sum(1)
-         end
-         Numo::DFloat[*log_likelihoods].transpose
-       end
-
-       # Dump marshal data.
-       #
-       # @return [Hash] The marshal data about MultinomialNB.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           class_priors: @class_priors,
-           feature_probs: @feature_probs }
-       end
-
-       # Load marshal data.
-       #
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @class_priors = obj[:class_priors]
-         @feature_probs = obj[:feature_probs]
-         nil
-       end
-     end
-
-     # BernoulliNB is a class that implements Bernoulli Naive Bayes classifier.
-     #
-     # @example
-     #   estimator = SVMKit::NaiveBayes::BernoulliNB.new(smoothing_param: 1.0, bin_threshold: 0.0)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     #
-     # *Reference*
-     # - C. D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press, 2008.
-     class BernoulliNB < BaseNaiveBayes
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Return the prior probabilities of the classes.
-       # @return [Numo::DFloat] (shape: [n_classes])
-       attr_reader :class_priors
-
-       # Return the conditional probabilities for features of each class.
-       # @return [Numo::DFloat] (shape: [n_classes, n_features])
-       attr_reader :feature_probs
-
-       # Create a new classifier with Bernoulli Naive Bayes.
-       #
-       # @param smoothing_param [Float] The Laplace smoothing parameter.
-       # @param bin_threshold [Float] The threshold for binarizing of features.
-       def initialize(smoothing_param: 1.0, bin_threshold: 0.0)
-         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param, bin_threshold: bin_threshold)
-         SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
-         @params = {}
-         @params[:smoothing_param] = smoothing_param
-         @params[:bin_threshold] = bin_threshold
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
-       #   to be used for fitting the model.
-       # @return [BernoulliNB] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         n_samples, = x.shape
-         bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
-         @classes = Numo::Int32[*y.to_a.uniq.sort]
-         n_samples_each_class = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count.to_f }]
-         @class_priors = n_samples_each_class / n_samples
-         count_features = Numo::DFloat[*@classes.to_a.map { |l| bin_x[y.eq(l).where, true].sum(0) }]
-         count_features += @params[:smoothing_param]
-         n_samples_each_class += 2.0 * @params[:smoothing_param]
-         n_classes = @classes.size
-         @feature_probs = count_features / n_samples_each_class.reshape(n_classes, 1)
-         self
-       end
-
-       # Calculate confidence scores for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
-       def decision_function(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_classes = @classes.size
-         bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
-         not_bin_x = Numo::DFloat[*x.le(@params[:bin_threshold])]
-         log_likelihoods = Array.new(n_classes) do |l|
-           Math.log(@class_priors[l]) + (
-             (Numo::DFloat[*bin_x] * Numo::NMath.log(@feature_probs[l, true])).sum(1) +
-             (Numo::DFloat[*not_bin_x] * Numo::NMath.log(1.0 - @feature_probs[l, true])).sum(1))
-         end
-         Numo::DFloat[*log_likelihoods].transpose
-       end
-
-       # Dump marshal data.
-       #
-       # @return [Hash] The marshal data about BernoulliNB.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           class_priors: @class_priors,
-           feature_probs: @feature_probs }
-       end
-
-       # Load marshal data.
-       #
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @class_priors = obj[:class_priors]
-         @feature_probs = obj[:feature_probs]
-         nil
-       end
-     end
-   end
- end
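
Outside the diff itself, for context: a minimal usage sketch of the GaussianNB classifier removed above, following its @example block. It assumes svmkit 0.7.3 (the last version shipping this file) with numo-narray installed; the toy samples and the expected predictions are illustrative assumptions, not output taken from the gem.

```ruby
# Minimal sketch (assumes svmkit 0.7.3 and numo-narray; toy data is invented).
require 'svmkit'

# Two well-separated clusters, labeled 0 and 1.
x_train = Numo::DFloat[[0.0, 0.1], [0.2, 0.0], [5.0, 5.1], [5.2, 4.9]]
y_train = Numo::Int32[0, 0, 1, 1]

estimator = SVMKit::NaiveBayes::GaussianNB.new
estimator.fit(x_train, y_train)

x_test = Numo::DFloat[[0.1, 0.0], [5.1, 5.0]]
p estimator.predict(x_test)       # expected to be close to Numo::Int32[0, 1]
p estimator.predict_proba(x_test) # per-class probabilities, shape [2, 2]
```

MultinomialNB and BernoulliNB, removed in the same file, share the fit/predict/predict_proba interface and differ only in their constructor keywords (smoothing_param, bin_threshold).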
data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb
@@ -1,112 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/classifier'
-
- module SVMKit
-   # This module consists of the classes that implement estimators based on nearest neighbors rule.
-   module NearestNeighbors
-     # KNeighborsClassifier is a class that implements the classifier with the k-nearest neighbors rule.
-     # The current implementation uses the Euclidean distance for finding the neighbors.
-     #
-     # @example
-     #   estimator =
-     #     SVMKit::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 5)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     #
-     class KNeighborsClassifier
-       include Base::BaseEstimator
-       include Base::Classifier
-
-       # Return the prototypes for the nearest neighbor classifier.
-       # @return [Numo::DFloat] (shape: [n_samples, n_features])
-       attr_reader :prototypes
-
-       # Return the labels of the prototypes.
-       # @return [Numo::Int32] (size: n_samples)
-       attr_reader :labels
-
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Create a new classifier with the nearest neighbor rule.
-       #
-       # @param n_neighbors [Integer] The number of neighbors.
-       def initialize(n_neighbors: 5)
-         SVMKit::Validation.check_params_integer(n_neighbors: n_neighbors)
-         SVMKit::Validation.check_params_positive(n_neighbors: n_neighbors)
-         @params = {}
-         @params[:n_neighbors] = n_neighbors
-         @prototypes = nil
-         @labels = nil
-         @classes = nil
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
-       # @return [KNeighborsClassifier] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         @prototypes = Numo::DFloat.asarray(x.to_a)
-         @labels = Numo::Int32.asarray(y.to_a)
-         @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
-         self
-       end
-
-       # Calculate confidence scores for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
-       def decision_function(x)
-         SVMKit::Validation.check_sample_array(x)
-         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
-         n_samples, n_prototypes = distance_matrix.shape
-         n_classes = @classes.size
-         n_neighbors = [@params[:n_neighbors], n_prototypes].min
-         scores = Numo::DFloat.zeros(n_samples, n_classes)
-         n_samples.times do |m|
-           neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
-           neighbor_ids.each { |n| scores[m, @classes.to_a.index(@labels[n])] += 1.0 }
-         end
-         scores
-       end
-
-       # Predict class labels for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples = x.shape.first
-         decision_values = decision_function(x)
-         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about KNeighborsClassifier.
-       def marshal_dump
-         { params: @params,
-           prototypes: @prototypes,
-           labels: @labels,
-           classes: @classes }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @prototypes = obj[:prototypes]
-         @labels = obj[:labels]
-         @classes = obj[:classes]
-         nil
-       end
-     end
-   end
- end
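
Again outside the diff, a minimal sketch of how the removed KNeighborsClassifier was used, per its @example block; svmkit 0.7.3 and the toy data below are assumptions for illustration.

```ruby
# Minimal sketch (assumes svmkit 0.7.3; toy data is invented).
require 'svmkit'

x_train = Numo::DFloat[[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1]]
y_train = Numo::Int32[0, 0, 1, 1]

estimator = SVMKit::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 3)
estimator.fit(x_train, y_train)

x_test = Numo::DFloat[[0.05, 0.1]]
p estimator.decision_function(x_test) # vote counts among the 3 nearest prototypes
p estimator.predict(x_test)           # expected to be Numo::Int32[0]
```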
data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb
@@ -1,94 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/regressor'
-
- module SVMKit
-   module NearestNeighbors
-     # KNeighborsRegressor is a class that implements the regressor with the k-nearest neighbors rule.
-     # The current implementation uses the Euclidean distance for finding the neighbors.
-     #
-     # @example
-     #   estimator =
-     #     SVMKit::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 5)
-     #   estimator.fit(training_samples, training_target_values)
-     #   results = estimator.predict(testing_samples)
-     #
-     class KNeighborsRegressor
-       include Base::BaseEstimator
-       include Base::Regressor
-
-       # Return the prototypes for the nearest neighbor regressor.
-       # @return [Numo::DFloat] (shape: [n_samples, n_features])
-       attr_reader :prototypes
-
-       # Return the values of the prototypes.
-       # @return [Numo::DFloat] (shape: [n_samples, n_outputs])
-       attr_reader :values
-
-       # Create a new regressor with the nearest neighbor rule.
-       #
-       # @param n_neighbors [Integer] The number of neighbors.
-       def initialize(n_neighbors: 5)
-         SVMKit::Validation.check_params_integer(n_neighbors: n_neighbors)
-         SVMKit::Validation.check_params_positive(n_neighbors: n_neighbors)
-         @params = {}
-         @params[:n_neighbors] = n_neighbors
-         @prototypes = nil
-         @values = nil
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
-       # @return [KNeighborsRegressor] The learned regressor itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_tvalue_array(y)
-         SVMKit::Validation.check_sample_tvalue_size(x, y)
-         @prototypes = x.dup
-         @values = y.dup
-         self
-       end
-
-       # Predict values for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
-       # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         # Initialize some variables.
-         n_samples, = x.shape
-         n_prototypes, n_outputs = @values.shape
-         n_neighbors = [@params[:n_neighbors], n_prototypes].min
-         # Calculate distance matrix.
-         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
-         # Predict values for the given samples.
-         predicted_values = Array.new(n_samples) do |n|
-           neighbor_ids = distance_matrix[n, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
-           n_outputs.nil? ? @values[neighbor_ids].mean : @values[neighbor_ids, true].mean(0).to_a
-         end
-         Numo::DFloat[*predicted_values]
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about KNeighborsRegressor.
-       def marshal_dump
-         { params: @params,
-           prototypes: @prototypes,
-           values: @values }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @prototypes = obj[:prototypes]
-         @values = obj[:values]
-         nil
-       end
-     end
-   end
- end
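
For completeness, a sketch of the removed KNeighborsRegressor under the same assumptions (svmkit 0.7.3, invented single-feature data); with a 1-D target vector, predict averages the targets of the nearest prototypes.

```ruby
# Minimal sketch (assumes svmkit 0.7.3; toy single-output data is invented).
require 'svmkit'

x_train = Numo::DFloat[[0.0], [1.0], [2.0], [3.0]]
y_train = Numo::DFloat[0.0, 1.0, 2.0, 3.0]

estimator = SVMKit::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 2)
estimator.fit(x_train, y_train)

p estimator.predict(Numo::DFloat[[1.2]]) # mean of the two nearest targets, about 1.5
```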