svmkit 0.7.3 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
@@ -1,143 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/values'
5
- require 'svmkit/base/base_estimator'
6
- require 'svmkit/base/regressor'
7
- require 'svmkit/tree/decision_tree_regressor'
8
-
9
module SVMKit
  module Ensemble
    # RandomForestRegressor is a class that implements random forest for regression.
    #
    # @example
    #   estimator =
    #     SVMKit::Ensemble::RandomForestRegressor.new(
    #       n_estimators: 10, criterion: 'mse', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    class RandomForestRegressor
      include Base::BaseEstimator
      include Base::Regressor
      include Validation

      # Return the set of estimators.
      # @return [Array<DecisionTreeRegressor>]
      attr_reader :estimators

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the random generator used to draw the bootstrap sample indices and the seed of each tree.
      # @return [Random]
      attr_reader :rng

      # Create a new regressor with random forest.
      #
      # @param n_estimators [Integer] The number of decision trees for constructing random forest.
      # @param criterion [String] The function to evaluate splitting point. It is passed unchanged to
      #   Tree::DecisionTreeRegressor, so the accepted values are those supported by that class (default: 'mse').
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, the integer part of sqrt(n_features) is used.
      #   In either case the value is clamped to [1, n_features] when fitting.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   If nil is given, the value returned by srand is used.
      #   The generator draws the bootstrap samples and the random seed of each decision tree.
      def initialize(n_estimators: 10,
                     criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
                     max_features: nil, random_seed: nil)
        check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                          max_features: max_features, random_seed: random_seed)
        check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
        check_params_string(criterion: criterion)
        check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
                              max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
                              max_features: max_features)
        @params = {}
        @params[:n_estimators] = n_estimators
        @params[:criterion] = criterion
        @params[:max_depth] = max_depth
        @params[:max_leaf_nodes] = max_leaf_nodes
        @params[:min_samples_leaf] = min_samples_leaf
        @params[:max_features] = max_features
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @estimators = nil
        @feature_importances = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      # Each tree is trained on a bootstrap sample (drawn with replacement) of the training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      # @return [RandomForestRegressor] The learned regressor itself.
      def fit(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        # Initialize some variables.
        n_samples, n_features = x.shape
        # NOTE: the resolved max_features is written back into @params, so a later fit reuses it
        # (still clamped to that data's n_features) instead of recomputing sqrt(n_features).
        @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
        @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
        @feature_importances = Numo::DFloat.zeros(n_features)
        # A one-dimensional target array has no second dimension.
        single_target = y.shape[1].nil?
        # Construct forest.
        @estimators = Array.new(@params[:n_estimators]) do
          tree = Tree::DecisionTreeRegressor.new(
            criterion: @params[:criterion], max_depth: @params[:max_depth],
            max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
            max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
          )
          bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
          tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
          @feature_importances += tree.feature_importances
          tree
        end
        # Normalize the importances so that they sum to one. Skip this when the sum is zero
        # (every tree reported all-zero importances), because dividing would turn every entry into NaN.
        normalizer = @feature_importances.sum
        @feature_importances /= normalizer if normalizer > 0.0
        self
      end

      # Predict values for samples as the average of the predictions of all trees.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted value per sample.
      def predict(x)
        check_sample_array(x)
        @estimators.map { |est| est.predict(x) }.reduce(&:+) / @params[:n_estimators]
      end

      # Return the index of the leaf that each sample reached in each tree.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign each leaf.
      # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
      def apply(x)
        check_sample_array(x)
        Numo::Int32[*@estimators.map { |tree| tree.apply(x) }].transpose
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about RandomForestRegressor.
      def marshal_dump
        { params: @params,
          estimators: @estimators,
          feature_importances: @feature_importances,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @estimators = obj[:estimators]
        @feature_importances = obj[:feature_importances]
        @rng = obj[:rng]
        nil
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Accuracy is a class that calculates the accuracy of classifier from the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Accuracy.new
    #   puts evaluator.score(ground_truth, predicted)
    class Accuracy
      include Base::Evaluator

      # Calculate mean accuracy, i.e. the fraction of samples whose predicted label equals the ground truth label.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Mean accuracy (NaN when both label arrays are empty).
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        # Element-wise comparison gives a bit mask; its count is the number of correct predictions.
        y_true.eq(y_pred).count.fdiv(y_true.size)
      end
    end
  end
end
@@ -1,51 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/evaluation_measure/precision_recall'
6
-
7
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # FScore is a class that calculates the F1-score of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::FScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class FScore
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of F1-score.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for F1-score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        @average = average
      end

      # Calculate average F1-score.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If the average type is not one of 'binary', 'micro' or 'macro'.
      # @return [Float] Average F1-score
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        case @average
        when 'binary'
          # The last entry of the per-class scores; its class ordering is defined by PrecisionRecall
          # (presumably the largest label, i.e. the positive class; verify against PrecisionRecall).
          f_score_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_f_score(y_true, y_pred)
        when 'macro'
          macro_average_f_score(y_true, y_pred)
        else
          # Previously an unknown average type silently returned nil.
          raise ArgumentError, "Invalid average type: #{@average}. Expected 'binary', 'micro', or 'macro'."
        end
      end
    end
  end
end
@@ -1,46 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/preprocessing/one_hot_encoder'
6
-
7
module SVMKit
  module EvaluationMeasure
    # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::LogLoss.new
    #   puts evaluator.score(ground_truth, predicted)
    class LogLoss
      include Base::Evaluator

      # Calculate mean logarithmic loss.
      # When y_pred is one-dimensional (shape: [n_samples]), each entry is read as the probability of the
      # positive class and the binary log loss is computed. Otherwise each row of y_pred is treated as a
      # class-probability vector and the multiclass cross-entropy is computed.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes] or [n_samples]) Predicted class probability.
      # @param eps [Float] A small value close to zero. Probabilities are clipped to [eps, 1 - eps] so that the
      #   logarithm never evaluates to infinity.
      # @return [Float] mean logarithmic loss
      def score(y_true, y_pred, eps = 1e-15)
        SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
        SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)

        n_samples, n_classes = y_pred.shape
        probs = y_pred.clip(eps, 1 - eps)

        per_sample_loss =
          if n_classes.nil?
            binary_log_loss(y_true, probs)
          else
            multiclass_log_loss(y_true, probs)
          end
        per_sample_loss.sum / n_samples
      end

      private

      # Per-sample negative log-likelihood for binary labels.
      # The smallest label found in y_true is the negative class; every other label counts as positive.
      # NOTE(review): if y_true holds a single distinct label, that label is treated as negative even when
      # it is the positive class. Confirm whether callers can pass such batches.
      def binary_log_loss(y_true, probs)
        negative_label = y_true.to_a.uniq.min
        is_positive = Numo::DFloat.cast(y_true.ne(negative_label))
        -(is_positive * Numo::NMath.log(probs) + (1 - is_positive) * Numo::NMath.log(1 - probs))
      end

      # Per-sample cross-entropy for multiclass labels.
      # After clipping, each row of probs is renormalized to sum to one.
      def multiclass_log_loss(y_true, probs)
        one_hot = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(y_true)
        normalized = probs / probs.sum(1).expand_dims(1)
        -(one_hot * Numo::NMath.log(normalized)).sum(1)
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # MeanAbsoluteError is a class that calculates the mean absolute error.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanAbsoluteError
      include Base::Evaluator

      # Calculate mean absolute error: the average of |y_true - y_pred| taken over every element.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred differ in shape.
      # @return [Float] Mean absolute error
      def score(y_true, y_pred)
        [y_true, y_pred].each { |values| SVMKit::Validation.check_tvalue_array(values) }
        if y_true.shape != y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        absolute_errors = (y_true - y_pred).abs
        absolute_errors.mean
      end
    end
  end
end
@@ -1,30 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # MeanSquaredError is a class that calculates the mean squared error.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanSquaredError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredError
      include Base::Evaluator

      # Calculate mean squared error: the average of (y_true - y_pred)**2 taken over every element.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred differ in shape.
      # @return [Float] Mean squared error
      def score(y_true, y_pred)
        [y_true, y_pred].each { |values| SVMKit::Validation.check_tvalue_array(values) }
        same_shape = y_true.shape == y_pred.shape
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless same_shape

        residuals = y_true - y_pred
        (residuals**2).mean
      end
    end
  end
end
@@ -1,63 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
-
6
module SVMKit
  module EvaluationMeasure
    # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::NormalizedMutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
    # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class NormalizedMutualInformation
      include Base::Evaluator

      # Calculate normalized mutual information, NMI = I(y_true; y_pred) / sqrt(H(y_true) * H(y_pred)),
      # using natural logarithms.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @raise [ArgumentError] If y_true and y_pred do not have the same shape.
      # @return [Float] Normalized mutual information. Returns 0.0 when either labeling has zero entropy,
      #   which includes empty input.
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        n_samples = y_pred.size
        class_ids = y_true.to_a.uniq
        cluster_ids = y_pred.to_a.uniq
        # The normalizer is zero when either labeling is constant; report 0.0 instead of dividing by zero.
        class_entropy = entropy(y_true, class_ids, n_samples)
        return 0.0 if class_entropy.zero?

        cluster_entropy = entropy(y_pred, cluster_ids, n_samples)
        return 0.0 if cluster_entropy.zero?

        # Compute the sample indices of each ground-truth class once, instead of once per cluster.
        class_members = class_ids.map { |j| y_true.eq(j).where.to_a }
        mutual_information = 0.0
        cluster_ids.each do |k|
          pr_sample_ids = y_pred.eq(k).where.to_a
          n_pr_samples = pr_sample_ids.size
          class_members.each do |tr_sample_ids|
            n_tr_samples = tr_sample_ids.size
            n_intr_samples = (pr_sample_ids & tr_sample_ids).size
            # Empty intersections contribute 0 * log(0) := 0.
            next if n_intr_samples.zero?

            mutual_information +=
              n_intr_samples.fdiv(n_samples) * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
          end
        end
        mutual_information / Math.sqrt(class_entropy * cluster_entropy)
      end

      private

      # Shannon entropy (natural log) of the empirical distribution of labels over the given label ids.
      # Returns zero (possibly -0.0) when ids is empty.
      def entropy(labels, ids, n_samples)
        total = ids.inject(0.0) do |acc, id|
          ratio = labels.eq(id).count.fdiv(n_samples)
          acc + ratio * Math.log(ratio)
        end
        -1.0 * total
      end
    end
  end
end
@@ -1,51 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/evaluator'
5
- require 'svmkit/evaluation_measure/precision_recall'
6
-
7
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Precision is a class that calculates the precision of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Precision.new
    #   puts evaluator.score(ground_truth, predicted)
    class Precision
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of precision.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for precision score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        @average = average
      end

      # Calculate average precision.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If the average type is not one of 'binary', 'micro' or 'macro'.
      # @return [Float] Average precision
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        case @average
        when 'binary'
          # The last entry of the per-class precisions; its class ordering is defined by PrecisionRecall
          # (presumably the largest label, i.e. the positive class; verify against PrecisionRecall).
          precision_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_precision(y_true, y_pred)
        when 'macro'
          macro_average_precision(y_true, y_pred)
        else
          # Previously an unknown average type silently returned nil.
          raise ArgumentError, "Invalid average type: #{@average}. Expected 'binary', 'micro', or 'macro'."
        end
      end
    end
  end
end