svmkit 0.7.3 → 0.8.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
@@ -1,91 +0,0 @@
# frozen_string_literal: true

require 'svmkit/base/evaluator'

module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Shared helpers for computing precision, recall, and F-score from label arrays.
    # The set of classes is always the sorted unique labels of y_true, so per-class
    # results are ordered by ascending label. Degenerate ratios (a class that is never
    # predicted, or no overlap at all) yield 0.0 instead of NaN.
    #
    # @!visibility private
    module PrecisionRecall
      module_function

      # Compute the precision of each class.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Array<Float>] Precision per class (0.0 for a class that is never predicted).
      # @!visibility private
      def precision_each_class(y_true, y_pred)
        y_true.sort.to_a.uniq.map do |label|
          target_positions = y_pred.eq(label)
          next 0.0 if y_pred[target_positions].empty?
          n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
          n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
          n_true_positives / (n_true_positives + n_false_positives)
        end
      end

      # Compute the recall of each class.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Array<Float>] Recall per class.
      # @!visibility private
      def recall_each_class(y_true, y_pred)
        y_true.sort.to_a.uniq.map do |label|
          target_positions = y_true.eq(label)
          next 0.0 if y_pred[target_positions].empty?
          n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
          n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
          n_true_positives / (n_true_positives + n_false_negatives)
        end
      end

      # Compute the F1-score of each class from the per-class precision and recall.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Array<Float>] F1-score per class (0.0 when both precision and recall are zero).
      # @!visibility private
      def f_score_each_class(y_true, y_pred)
        precision_each_class(y_true, y_pred).zip(recall_each_class(y_true, y_pred)).map do |p, r|
          next 0.0 if p.zero? && r.zero?
          (2.0 * p * r) / (p + r)
        end
      end

      # Compute the micro-averaged precision: total true positives divided by the total
      # number of predictions that carry one of the labels found in y_true.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Micro-averaged precision.
      # @!visibility private
      def micro_average_precision(y_true, y_pred)
        evaluated_values = y_true.sort.to_a.uniq.map do |label|
          target_positions = y_pred.eq(label)
          next [0.0, 0.0] if y_pred[target_positions].empty?
          n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
          n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
          [n_true_positives, n_true_positives + n_false_positives]
        end
        n_true_positives, n_predicted = evaluated_values.transpose.map { |v| v.inject(:+) }
        # Every prediction uses a label absent from y_true: avoid 0 / 0 (NaN).
        return 0.0 if n_predicted.zero?
        n_true_positives / n_predicted
      end

      # Compute the micro-averaged recall: total true positives divided by the total
      # number of samples whose ground-truth label is one of the evaluated classes.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Micro-averaged recall.
      # @!visibility private
      def micro_average_recall(y_true, y_pred)
        evaluated_values = y_true.sort.to_a.uniq.map do |label|
          target_positions = y_true.eq(label)
          # Must be a [true positives, total] pair like the other branch, or transpose below breaks.
          next [0.0, 0.0] if y_pred[target_positions].empty?
          n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
          n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
          [n_true_positives, n_true_positives + n_false_negatives]
        end
        n_true_positives, n_actual = evaluated_values.transpose.map { |v| v.inject(:+) }
        return 0.0 if n_actual.zero?
        n_true_positives / n_actual
      end

      # Compute the micro-averaged F1-score (harmonic mean of micro precision and recall).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Micro-averaged F1-score (0.0 when precision and recall are both zero).
      # @!visibility private
      def micro_average_f_score(y_true, y_pred)
        precision = micro_average_precision(y_true, y_pred)
        recall = micro_average_recall(y_true, y_pred)
        # Same convention as f_score_each_class: no NaN from 0 / 0.
        return 0.0 if (precision + recall).zero?
        (2.0 * precision * recall) / (precision + recall)
      end

      # Compute the macro-averaged precision (unweighted mean of the per-class precisions).
      #
      # @return [Float] Macro-averaged precision.
      # @!visibility private
      def macro_average_precision(y_true, y_pred)
        precision_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
      end

      # Compute the macro-averaged recall (unweighted mean of the per-class recalls).
      #
      # @return [Float] Macro-averaged recall.
      # @!visibility private
      def macro_average_recall(y_true, y_pred)
        recall_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
      end

      # Compute the macro-averaged F1-score (unweighted mean of the per-class F1-scores).
      #
      # @return [Float] Macro-averaged F1-score.
      # @!visibility private
      def macro_average_f_score(y_true, y_pred)
        f_score_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
      end
    end
  end
end
@@ -1,41 +0,0 @@
# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/evaluator'

module SVMKit
  module EvaluationMeasure
    # Purity is a class that calculates the purity of clustering results.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Purity.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
    class Purity
      include Base::Evaluator

      # Calculate purity.
      #
      # For every predicted cluster, count the members that carry the cluster's most
      # frequent ground-truth label; purity is the sum of those counts divided by the
      # number of samples.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @raise [ArgumentError] If y_true and y_pred have different sizes.
      # @return [Float] Purity
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.size == y_pred.size

        # Grouping each cluster's ground-truth labels once is linear per cluster, instead of
        # intersecting index arrays for every (cluster, class) pair.
        purity = y_pred.to_a.uniq.inject(0) do |total, cluster_id|
          member_labels = y_true[y_pred.eq(cluster_id)].to_a
          total + member_labels.group_by(&:itself).values.map(&:size).max
        end
        purity.fdiv(y_pred.size)
      end
    end
  end
end
@@ -1,44 +0,0 @@
# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/evaluator'
# NOTE(review): nothing in R2Score uses PrecisionRecall; kept in case other code relies on this load.
require 'svmkit/evaluation_measure/precision_recall'

module SVMKit
  module EvaluationMeasure
    # R2Score is a class that calculates the coefficient of determination for the predicted values.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::R2Score.new
    #   puts evaluator.score(ground_truth, predicted)
    class R2Score
      include Base::Evaluator

      # Create a new evaluation measure calculator for coefficient of determination.
      def initialize; end

      # Calculate the coefficient of determination.
      #
      # For multi-output targets, the per-output scores are averaged uniformly. An output
      # whose ground truth has zero total sum of squares (constant target) scores 0.0.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred have different shapes.
      # @return [Float] Coefficient of determination
      def score(y_true, y_pred)
        SVMKit::Validation.check_tvalue_array(y_true)
        SVMKit::Validation.check_tvalue_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        n_samples, n_outputs = y_true.shape
        # Residual sum of squares and total sum of squares, column by column.
        numerator = ((y_true - y_pred)**2).sum(0)
        yt_mean = y_true.sum(0) / n_samples
        denominator = ((y_true - yt_mean)**2).sum(0)
        if n_outputs.nil?
          denominator.zero? ? 0.0 : 1.0 - numerator / denominator
        else
          scores = 1 - numerator / denominator
          scores[denominator.eq(0)] = 0.0
          scores.mean
        end
      end
    end
  end
end
@@ -1,51 +0,0 @@
# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/evaluator'
require 'svmkit/evaluation_measure/precision_recall'

module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Recall is a class that calculates the recall of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Recall.new
    #   puts evaluator.score(ground_truth, predicted)
    class Recall
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of recall.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for recall score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        @average = average
      end

      # Calculate average recall.
      #
      # With 'binary', the recall of the class with the largest label is returned,
      # since per-class results are ordered by ascending label.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If the average type is not 'binary', 'micro', or 'macro'.
      # @return [Float] Average recall
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        case @average
        when 'binary'
          recall_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_recall(y_true, y_pred)
        when 'macro'
          macro_average_recall(y_true, y_pred)
        else
          # Fail loudly instead of silently returning nil for an unsupported average type.
          raise ArgumentError, "Unknown average type: #{@average}"
        end
      end
    end
  end
end
@@ -1,136 +0,0 @@
# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/base_estimator'
require 'svmkit/base/transformer'

module SVMKit
  # Module for kernel approximation algorithms.
  module KernelApproximation
    # Class for RBF kernel feature mapping.
    #
    # @example
    #   transformer = SVMKit::KernelApproximation::RBF.new(gamma: 1.0, n_components: 128, random_seed: 1)
    #   new_training_samples = transformer.fit_transform(training_samples)
    #   new_testing_samples = transformer.transform(testing_samples)
    #
    # *Reference*:
    # 1. A. Rahimi and B. Recht, "Random Features for Large-Scale Kernel Machines," Proc. NIPS'07, pp.1177--1184, 2007.
    class RBF
      include Base::BaseEstimator
      include Base::Transformer

      # Return the random matrix for transformation.
      # @return [Numo::DFloat] (shape: [n_features, n_components])
      attr_reader :random_mat

      # Return the random vector for transformation.
      # @return [Numo::DFloat] (shape: [n_components])
      attr_reader :random_vec

      # Return the random generator for transformation.
      # @return [Random]
      attr_reader :rng

      # Create a new transformer for mapping to RBF kernel feature space.
      #
      # @param gamma [Float] The parameter of RBF kernel: exp(-gamma * x^2).
      # @param n_components [Integer] The number of dimensions of the RBF kernel feature space.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, a seed is obtained from Kernel#srand.
      def initialize(gamma: 1.0, n_components: 128, random_seed: nil)
        SVMKit::Validation.check_params_float(gamma: gamma)
        SVMKit::Validation.check_params_integer(n_components: n_components)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
        SVMKit::Validation.check_params_positive(gamma: gamma, n_components: n_components)
        @params = {}
        @params[:gamma] = gamma
        @params[:n_components] = n_components
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @random_mat = nil
        @random_vec = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @overload fit(x) -> RBF
      #
      # @param x [Numo::NArray] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   This method uses only the number of features of the data.
      # @return [RBF] The learned transformer itself.
      def fit(x, _y = nil)
        SVMKit::Validation.check_sample_array(x)

        n_features = x.shape[1]
        # The constructor rejects non-positive n_components, so this fallback only matters
        # when @params was set some other way (e.g. via marshal_load).
        @params[:n_components] = 2 * n_features if @params[:n_components] <= 0
        # Standard-normal draws scaled by sqrt(2 * gamma) give random frequencies with variance 2 * gamma.
        @random_mat = rand_normal([n_features, @params[:n_components]]) * (2.0 * @params[:gamma])**0.5
        # Phase offsets: 0 for the first half of the components and pi/2 for the rest,
        # so sin(w.x + b) yields both sine and cosine features.
        n_half_components = @params[:n_components] / 2
        @random_vec = Numo::DFloat.zeros(@params[:n_components] - n_half_components).concatenate(
          Numo::DFloat.ones(n_half_components) * (0.5 * Math::PI)
        )
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, _y = nil)
        SVMKit::Validation.check_sample_array(x)

        fit(x).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @overload transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      def transform(x)
        SVMKit::Validation.check_sample_array(x)

        # The [n_components] offset vector broadcasts over every row of the projection,
        # so there is no need to tile it into an [n_samples, n_components] matrix.
        projection = x.dot(@random_mat) + @random_vec
        Numo::NMath.sin(projection) * ((2.0 / @params[:n_components])**0.5)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about RBF.
      def marshal_dump
        { params: @params,
          random_mat: @random_mat,
          random_vec: @random_vec,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @random_mat = obj[:random_mat]
        @random_vec = obj[:random_vec]
        @rng = obj[:rng]
        nil
      end

      private

      # Generate the uniform random matrix with the given shape.
      def rand_uniform(shape)
        rnd_vals = Array.new(shape.inject(:*)) { @rng.rand }
        Numo::DFloat.asarray(rnd_vals).reshape(shape[0], shape[1])
      end

      # Generate the normal random matrix with the given shape, mean, and standard deviation
      # using the Box-Muller transform.
      # NOTE(review): @rng.rand may return exactly 0.0, making log(a) = -Infinity. The chance is
      # negligible, and changing the draws would alter results for existing seeds, so it is kept.
      def rand_normal(shape, mu = 0.0, sigma = 1.0)
        a = rand_uniform(shape)
        b = rand_uniform(shape)
        (Numo::NMath.sqrt(Numo::NMath.log(a) * -2.0) * Numo::NMath.sin(b * 2.0 * Math::PI)) * sigma + mu
      end
    end
  end
end
@@ -1,194 +0,0 @@
# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/base_estimator'
require 'svmkit/base/classifier'
require 'svmkit/probabilistic_output'

module SVMKit
  # This module consists of the classes that implement kernel method-based estimator.
  module KernelMachine
    # KernelSVC is a class that implements (Nonlinear) Kernel Support Vector Classifier
    # with stochastic gradient descent (SGD) optimization.
    # For multiclass classification problem, it uses one-vs-the-rest strategy.
    #
    # @example
    #   training_kernel_matrix = SVMKit::PairwiseMetric.rbf_kernel(training_samples)
    #   estimator =
    #     SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
    #   estimator.fit(training_kernel_matrix, training_labels)
    #   testing_kernel_matrix = SVMKit::PairwiseMetric.rbf_kernel(testing_samples, training_samples)
    #   results = estimator.predict(testing_kernel_matrix)
    #
    # *Reference*
    # 1. S. Shalev-Shwartz, Y. Singer, N. Srebro, and A. Cotter, "Pegasos: Primal Estimated sub-GrAdient SOlver for SVM," Mathematical Programming, vol. 127 (1), pp. 3--30, 2011.
    class KernelSVC
      include Base::BaseEstimator
      include Base::Classifier

      # Return the weight vector for Kernel SVC.
      # @return [Numo::DFloat] (shape: [n_classes, n_training_samples] for multiclass problems,
      #   [n_training_samples] for binary problems)
      attr_reader :weight_vec

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Return the random generator for performing random sampling.
      # @return [Random]
      attr_reader :rng

      # Create a new classifier with Kernel Support Vector Machine by the SGD optimization.
      #
      # @param reg_param [Float] The regularization parameter.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param probability [Boolean] The flag indicating whether to perform probability estimation.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, a seed is obtained from Kernel#srand.
      def initialize(reg_param: 1.0, max_iter: 1000, probability: false, random_seed: nil)
        SVMKit::Validation.check_params_float(reg_param: reg_param)
        SVMKit::Validation.check_params_integer(max_iter: max_iter)
        SVMKit::Validation.check_params_boolean(probability: probability)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
        SVMKit::Validation.check_params_positive(reg_param: reg_param, max_iter: max_iter)
        @params = {}
        @params[:reg_param] = reg_param
        @params[:max_iter] = max_iter
        @params[:probability] = probability
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @weight_vec = nil
        @prob_param = nil
        @classes = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # Binary problems treat the smallest label as the negative class; multiclass problems
      # train one binary classifier per class (one-vs-the-rest).
      #
      # @param x [Numo::DFloat] (shape: [n_training_samples, n_training_samples])
      #   The kernel matrix of the training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_training_samples]) The labels to be used for fitting the model.
      # @return [KernelSVC] The learned classifier itself.
      def fit(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
        SVMKit::Validation.check_sample_label_size(x, y)

        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size
        _n_samples, n_features = x.shape

        if n_classes > 2
          @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
          @prob_param = Numo::DFloat.zeros(n_classes, 2)
          n_classes.times do |n|
            bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
            @weight_vec[n, true] = binary_fit(x, bin_y)
            @prob_param[n, true] = if @params[:probability]
                                     SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec[n, true].transpose), bin_y)
                                   else
                                     Numo::DFloat[1, 0]
                                   end
          end
        else
          negative_label = y.to_a.uniq.min
          bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
          @weight_vec = binary_fit(x, bin_y)
          @prob_param = if @params[:probability]
                          SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec.transpose), bin_y)
                        else
                          Numo::DFloat[1, 0]
                        end
        end

        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
      #   The kernel matrix between testing samples and training samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_testing_samples, n_classes]) Confidence score per sample.
      def decision_function(x)
        SVMKit::Validation.check_sample_array(x)

        x.dot(@weight_vec.transpose)
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
      #   The kernel matrix between testing samples and training samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_testing_samples]) Predicted class label per sample,
      #   taken from the labels seen during fit.
      def predict(x)
        SVMKit::Validation.check_sample_array(x)

        if @classes.size <= 2
          # #fit trains binary problems with the smallest label as the negative class, so map
          # non-negative scores to the largest label and the rest to the smallest one. This
          # matches the old -1/+1 output when the labels are -1 and 1, and @classes[-1] also
          # covers the degenerate single-class case.
          is_positive = Numo::Int32.cast(decision_function(x).ge(0.0))
          return is_positive * (@classes[-1] - @classes[0]) + @classes[0]
        end

        n_samples, = x.shape
        decision_values = decision_function(x)
        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
      #   The kernel matrix between testing samples and training samples to predict the labels.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        SVMKit::Validation.check_sample_array(x)

        if @classes.size > 2
          # Per-class sigmoid outputs, normalized so each row sums to one.
          probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
          return (probs.transpose / probs.sum(axis: 1)).transpose
        end

        n_samples, = x.shape
        probs = Numo::DFloat.zeros(n_samples, 2)
        probs[true, 1] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
        probs[true, 0] = 1.0 - probs[true, 1]
        probs
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about KernelSVC.
      def marshal_dump
        { params: @params,
          weight_vec: @weight_vec,
          prob_param: @prob_param,
          classes: @classes,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @weight_vec = obj[:weight_vec]
        @prob_param = obj[:prob_param]
        @classes = obj[:classes]
        @rng = obj[:rng]
        nil
      end

      private

      # Kernelized Pegasos for one binary problem with labels in {-1, +1}.
      # Returns the per-training-sample weights multiplied by their labels.
      def binary_fit(x, bin_y)
        # Initialize some variables.
        n_training_samples = x.shape[0]
        rand_ids = []
        weight_vec = Numo::DFloat.zeros(n_training_samples)
        # Start optimization.
        @params[:max_iter].times do |t|
          # Sample without replacement, reshuffling once every sample has been visited.
          rand_ids = [*0...n_training_samples].shuffle(random: @rng) if rand_ids.empty?
          target_id = rand_ids.shift
          # Count the sample as a violator when its scaled margin is below one.
          func = (weight_vec * bin_y).dot(x[target_id, true].transpose).to_f
          func *= bin_y[target_id] / (@params[:reg_param] * (t + 1))
          weight_vec[target_id] += 1.0 if func < 1.0
        end
        weight_vec * bin_y
      end
    end
  end
end