svmkit 0.7.3 → 0.8.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
@@ -1,143 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/values'
5
- require 'svmkit/base/base_estimator'
6
- require 'svmkit/base/regressor'
7
- require 'svmkit/tree/decision_tree_regressor'
8
-
9
module SVMKit
  module Ensemble
    # RandomForestRegressor is a class that implements random forest for regression.
    #
    # @example
    #   estimator =
    #     SVMKit::Ensemble::RandomForestRegressor.new(
    #       n_estimators: 10, criterion: 'mse', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    class RandomForestRegressor
      include Base::BaseEstimator
      include Base::Regressor
      include Validation

      # Return the set of estimators.
      # @return [Array<DecisionTreeRegressor>]
      attr_reader :estimators

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Create a new regressor with random forest.
      #
      # @param n_estimators [Integer] The number of decision trees for constructing random forest.
      # @param criterion [String] The function to evaluate splitting point. It is forwarded unchanged to each
      #   Tree::DecisionTreeRegressor (default 'mse'); see that class for the supported regression criteria.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, Math.sqrt(n_features).to_i features are used at fit time.
      #   In either case the value is clamped to the range [1, n_features].
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to draw bootstrap samples and the seeds of the individual trees.
      def initialize(n_estimators: 10,
                     criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
                     max_features: nil, random_seed: nil)
        check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                          max_features: max_features, random_seed: random_seed)
        check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
        check_params_string(criterion: criterion)
        check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
                              max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
                              max_features: max_features)
        @params = {}
        @params[:n_estimators] = n_estimators
        @params[:criterion] = criterion
        @params[:max_depth] = max_depth
        @params[:max_leaf_nodes] = max_leaf_nodes
        @params[:min_samples_leaf] = min_samples_leaf
        @params[:max_features] = max_features
        @params[:random_seed] = random_seed
        # Note: Kernel#srand reseeds the global generator and returns the previous seed.
        @params[:random_seed] ||= srand
        @estimators = nil
        @feature_importances = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # Each tree is trained on a bootstrap sample (n_samples rows drawn with replacement)
      # and receives its own random seed drawn from @rng.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      # @return [RandomForestRegressor] The learned regressor itself.
      def fit(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        n_samples, n_features = x.shape
        # Resolve into a local instead of overwriting @params[:max_features]: otherwise a refit on data
        # with a different number of features would reuse the value computed for the previous data.
        max_features = @params[:max_features].is_a?(Integer) ? @params[:max_features] : Math.sqrt(n_features).to_i
        max_features = [[1, max_features].max, n_features].min
        @feature_importances = Numo::DFloat.zeros(n_features)
        single_target = y.shape[1].nil?
        @estimators = Array.new(@params[:n_estimators]) do
          tree = Tree::DecisionTreeRegressor.new(
            criterion: @params[:criterion], max_depth: @params[:max_depth],
            max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
            max_features: max_features, random_seed: @rng.rand(SVMKit::Values.int_max)
          )
          bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
          tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
          @feature_importances += tree.feature_importances
          tree
        end
        # Guard against 0/0: if no tree produced any importance (e.g. a constant target),
        # keep the all-zero vector instead of turning it into NaN.
        total_importance = @feature_importances.sum
        @feature_importances /= total_importance if total_importance.positive?
        self
      end

      # Predict values for samples by averaging the predictions of all trees.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted value per sample.
      def predict(x)
        check_sample_array(x)
        @estimators.map { |est| est.predict(x) }.reduce(&:+) / @params[:n_estimators]
      end

      # Return the index of the leaf that each sample reached.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign each leaf.
      # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
      def apply(x)
        check_sample_array(x)
        Numo::Int32[*@estimators.map { |tree| tree.apply(x) }].transpose
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about RandomForestRegressor.
      def marshal_dump
        { params: @params,
          estimators: @estimators,
          feature_importances: @feature_importances,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @estimators = obj[:estimators]
        @feature_importances = obj[:feature_importances]
        @rng = obj[:rng]
        nil
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Accuracy is a class that calculates the accuracy of classifier from the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Accuracy.new
    #   puts evaluator.score(ground_truth, predicted)
    class Accuracy
      include Base::Evaluator

      # Calculate mean accuracy, i.e. the fraction of samples whose predicted label equals the true label.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Mean accuracy (NaN when both arrays are empty).
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        # Without this check a longer y_pred was silently truncated and a shorter one failed on indexing.
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        y_true.eq(y_pred).count.fdiv(y_true.size)
      end
    end
  end
end
@@ -1,51 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/evaluation_measure/precision_recall'
6
-
7
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # FScore is a class that calculates the F1-score of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::FScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class FScore
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of F1-score.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for F1-score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      # @raise [ArgumentError] If average is not one of the supported types.
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        # Fail fast: an unsupported value previously made #score silently return nil.
        unless %w[binary micro macro].include?(average)
          raise ArgumentError, "Expect average to be 'binary', 'micro', or 'macro', but got '#{average}'."
        end
        @average = average
      end

      # Calculate average F1-score.
      #
      # For 'binary', the last element of the per-class F1-scores from PrecisionRecall is returned
      # (presumably the positive class — ordering is defined in PrecisionRecall, not visible here).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Average F1-score
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        case @average
        when 'binary'
          f_score_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_f_score(y_true, y_pred)
        when 'macro'
          macro_average_f_score(y_true, y_pred)
        end
      end
    end
  end
end
@@ -1,46 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/preprocessing/one_hot_encoder'
6
-
7
module SVMKit
  module EvaluationMeasure
    # LogLoss computes the logarithmic loss (cross-entropy) of predicted class probabilities.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::LogLoss.new
    #   puts evaluator.score(ground_truth, predicted)
    class LogLoss
      include Base::Evaluator

      # Calculate mean logarithmic loss.
      #
      # When y_pred is one-dimensional (shape: [n_samples]), it is read as the probability of the
      # positive class in a binary problem, and the smallest label in y_true is taken as the negative class.
      # Otherwise y_true is one-hot encoded with Preprocessing::OneHotEncoder and each clipped row of
      # y_pred is renormalized before taking the log.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
      # @param eps [Float] A small value close to zero that keeps probabilities away from 0 and 1 so the log stays finite.
      # @return [Float] mean logarithmic loss
      def score(y_true, y_pred, eps = 1e-15)
        SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
        SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)

        n_samples, n_classes = y_pred.shape
        prob = y_pred.clip(eps, 1 - eps)

        per_sample_loss =
          if n_classes.nil?
            # Binary case: every label other than the smallest one counts as positive.
            is_positive = Numo::DFloat.cast(y_true.ne(y_true.to_a.min))
            -(is_positive * Numo::NMath.log(prob) + (1 - is_positive) * Numo::NMath.log(1 - prob))
          else
            one_hot = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(y_true)
            # Clipping can break the row sums, so rescale each row back to a distribution.
            prob /= prob.sum(1).expand_dims(1)
            -(one_hot * Numo::NMath.log(prob)).sum(1)
          end

        per_sample_loss.sum / n_samples
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # MeanAbsoluteError computes the average absolute difference between targets and estimates.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanAbsoluteError
      include Base::Evaluator

      # Calculate mean absolute error over every element of the two arrays.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Mean absolute error
      def score(y_true, y_pred)
        SVMKit::Validation.check_tvalue_array(y_true)
        SVMKit::Validation.check_tvalue_array(y_pred)
        if y_true.shape != y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        errors = y_true - y_pred
        errors.abs.mean
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # MeanSquaredError computes the average squared difference between targets and estimates.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanSquaredError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredError
      include Base::Evaluator

      # Calculate mean squared error over every element of the two arrays.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Mean squared error
      def score(y_true, y_pred)
        SVMKit::Validation.check_tvalue_array(y_true)
        SVMKit::Validation.check_tvalue_array(y_pred)
        if y_true.shape != y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        residuals = y_true - y_pred
        (residuals**2).mean
      end
    end
  end
end
@@ -1,63 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::NormalizedMutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
    # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class NormalizedMutualInformation
      include Base::Evaluator

      # Calculate normalized mutual information, MI / sqrt(H(classes) * H(clusters)).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Normalized mutual information (0.0 when either labeling has zero entropy, including empty input).
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        true_labels = y_true.to_a
        pred_labels = y_pred.to_a
        n_samples = pred_labels.size
        # Gather all counts in linear passes instead of a where/intersection per (cluster, class) pair.
        # Hash insertion order equals first-appearance order, so the summations below run in the same
        # order as iterating over `labels.uniq`.
        class_counts = Hash.new(0)
        true_labels.each { |label| class_counts[label] += 1 }
        cluster_counts = Hash.new(0)
        pred_labels.each { |label| cluster_counts[label] += 1 }
        joint_counts = Hash.new(0)
        pred_labels.zip(true_labels) { |pair| joint_counts[pair] += 1 }
        # calculate entropy.
        class_entropy = label_entropy(class_counts.values, n_samples)
        return 0.0 if class_entropy.zero?

        cluster_entropy = label_entropy(cluster_counts.values, n_samples)
        return 0.0 if cluster_entropy.zero?

        # calculate mutual information.
        mutual_information = 0.0
        cluster_counts.each do |cluster, n_pr_samples|
          class_counts.each do |klass, n_tr_samples|
            n_intr_samples = joint_counts[[cluster, klass]]
            next if n_intr_samples.zero?

            mutual_information +=
              n_intr_samples.fdiv(n_samples) * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
          end
        end
        # return normalized mutual information.
        mutual_information / Math.sqrt(class_entropy * cluster_entropy)
      end

      private

      # Shannon entropy (natural log) of a labeling given the size of each label group.
      # Returns 0.0 for an empty list of counts.
      def label_entropy(counts, n_samples)
        -1.0 * counts.map do |count|
          ratio = count.fdiv(n_samples)
          ratio * Math.log(ratio)
        end.inject(0.0, :+)
      end
    end
  end
end
@@ -1,51 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/evaluation_measure/precision_recall'
6
-
7
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Precision is a class that calculates the precision of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Precision.new
    #   puts evaluator.score(ground_truth, predicted)
    class Precision
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of precision.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for precision score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      # @raise [ArgumentError] If average is not one of the supported types.
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        # Fail fast: an unsupported value previously made #score silently return nil.
        unless %w[binary micro macro].include?(average)
          raise ArgumentError, "Expect average to be 'binary', 'micro', or 'macro', but got '#{average}'."
        end
        @average = average
      end

      # Calculate average precision.
      #
      # For 'binary', the last element of the per-class precisions from PrecisionRecall is returned
      # (presumably the positive class — ordering is defined in PrecisionRecall, not visible here).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Average precision
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        case @average
        when 'binary'
          precision_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_precision(y_true, y_pred)
        when 'macro'
          macro_average_precision(y_true, y_pred)
        end
      end
    end
  end
end