rumale 0.23.3 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE.txt +5 -1
- data/README.md +3 -288
- data/lib/rumale/version.rb +1 -1
- data/lib/rumale.rb +20 -131
- metadata +252 -150
- data/CHANGELOG.md +0 -643
- data/CODE_OF_CONDUCT.md +0 -74
- data/ext/rumale/extconf.rb +0 -37
- data/ext/rumale/rumaleext.c +0 -545
- data/ext/rumale/rumaleext.h +0 -12
- data/lib/rumale/base/base_estimator.rb +0 -49
- data/lib/rumale/base/classifier.rb +0 -36
- data/lib/rumale/base/cluster_analyzer.rb +0 -31
- data/lib/rumale/base/evaluator.rb +0 -17
- data/lib/rumale/base/regressor.rb +0 -36
- data/lib/rumale/base/splitter.rb +0 -21
- data/lib/rumale/base/transformer.rb +0 -22
- data/lib/rumale/clustering/dbscan.rb +0 -123
- data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
- data/lib/rumale/clustering/hdbscan.rb +0 -291
- data/lib/rumale/clustering/k_means.rb +0 -122
- data/lib/rumale/clustering/k_medoids.rb +0 -141
- data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
- data/lib/rumale/clustering/power_iteration.rb +0 -127
- data/lib/rumale/clustering/single_linkage.rb +0 -203
- data/lib/rumale/clustering/snn.rb +0 -76
- data/lib/rumale/clustering/spectral_clustering.rb +0 -115
- data/lib/rumale/dataset.rb +0 -246
- data/lib/rumale/decomposition/factor_analysis.rb +0 -150
- data/lib/rumale/decomposition/fast_ica.rb +0 -188
- data/lib/rumale/decomposition/nmf.rb +0 -124
- data/lib/rumale/decomposition/pca.rb +0 -159
- data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
- data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
- data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
- data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
- data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
- data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
- data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
- data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
- data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
- data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
- data/lib/rumale/ensemble/voting_classifier.rb +0 -126
- data/lib/rumale/ensemble/voting_regressor.rb +0 -82
- data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
- data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
- data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
- data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
- data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
- data/lib/rumale/evaluation_measure/f_score.rb +0 -50
- data/lib/rumale/evaluation_measure/function.rb +0 -147
- data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
- data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
- data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
- data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
- data/lib/rumale/evaluation_measure/precision.rb +0 -50
- data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
- data/lib/rumale/evaluation_measure/purity.rb +0 -40
- data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
- data/lib/rumale/evaluation_measure/recall.rb +0 -50
- data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
- data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
- data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
- data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
- data/lib/rumale/kernel_approximation/rbf.rb +0 -102
- data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
- data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
- data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
- data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
- data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
- data/lib/rumale/linear_model/base_sgd.rb +0 -285
- data/lib/rumale/linear_model/elastic_net.rb +0 -119
- data/lib/rumale/linear_model/lasso.rb +0 -115
- data/lib/rumale/linear_model/linear_regression.rb +0 -201
- data/lib/rumale/linear_model/logistic_regression.rb +0 -275
- data/lib/rumale/linear_model/nnls.rb +0 -137
- data/lib/rumale/linear_model/ridge.rb +0 -209
- data/lib/rumale/linear_model/svc.rb +0 -213
- data/lib/rumale/linear_model/svr.rb +0 -132
- data/lib/rumale/manifold/mds.rb +0 -155
- data/lib/rumale/manifold/tsne.rb +0 -222
- data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
- data/lib/rumale/metric_learning/mlkr.rb +0 -161
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
- data/lib/rumale/model_selection/cross_validation.rb +0 -125
- data/lib/rumale/model_selection/function.rb +0 -42
- data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
- data/lib/rumale/model_selection/group_k_fold.rb +0 -93
- data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
- data/lib/rumale/model_selection/k_fold.rb +0 -81
- data/lib/rumale/model_selection/shuffle_split.rb +0 -90
- data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
- data/lib/rumale/model_selection/time_series_split.rb +0 -91
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
- data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
- data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
- data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
- data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
- data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
- data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
- data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
- data/lib/rumale/neural_network/adam.rb +0 -56
- data/lib/rumale/neural_network/base_mlp.rb +0 -248
- data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
- data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
- data/lib/rumale/pairwise_metric.rb +0 -152
- data/lib/rumale/pipeline/feature_union.rb +0 -69
- data/lib/rumale/pipeline/pipeline.rb +0 -175
- data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
- data/lib/rumale/preprocessing/binarizer.rb +0 -60
- data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
- data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
- data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
- data/lib/rumale/preprocessing/label_encoder.rb +0 -79
- data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
- data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
- data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
- data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
- data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
- data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
- data/lib/rumale/probabilistic_output.rb +0 -114
- data/lib/rumale/tree/base_decision_tree.rb +0 -150
- data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
- data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
- data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
- data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
- data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
- data/lib/rumale/tree/node.rb +0 -39
- data/lib/rumale/utils.rb +0 -42
- data/lib/rumale/validation.rb +0 -128
- data/lib/rumale/values.rb +0 -13
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # CalinskiHarabaszScore is a class that calculates the Calinski and Harabasz score.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::CalinskiHarabaszScore.new
    #   puts evaluator.score(x, predicted)
    #
    # *Reference*
    # - Calinski, T., and Harabasz, J., "A dendrite method for cluster analysis," Communication in Statistics, Vol. 3 (1), pp. 1--27, 1972.
    class CalinskiHarabaszScore
      include Base::Evaluator

      # Calculates the Calinski and Harabasz score.
      #
      # The score is the between-cluster dispersion divided by (n_clusters - 1), over the
      # within-cluster dispersion divided by (n_samples - n_clusters).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for calculating score.
      # @param y [Numo::Int32] (shape: [n_samples]) The predicted labels for each sample.
      # @raise [ArgumentError] If y contains fewer than two distinct labels (the score is undefined).
      # @return [Float] The Calinski and Harabasz score.
      def score(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)

        labels = y.to_a.uniq.sort
        n_clusters = labels.size
        # The between-group term is normalized by (n_clusters - 1); with one cluster this would
        # divide by zero and yield NaN or Infinity instead of a meaningful score.
        raise ArgumentError, 'Expect to have at least two clusters to calculate the score.' if n_clusters < 2

        n_samples = x.shape[0]
        mean_vec = x.mean(0)
        within_group = 0.0
        between_group = 0.0
        labels.each do |label|
          cls_samples = x[y.eq(label), true]
          cls_centroid = cls_samples.mean(0)
          within_group += ((cls_samples - cls_centroid)**2).sum
          # Weight each centroid's squared distance from the global mean by the cluster size.
          between_group += cls_samples.shape[0] * ((cls_centroid - mean_vec)**2).sum
        end

        # Every sample coincides with its cluster centroid; keep the existing convention of returning 1.0.
        return 1.0 if within_group.zero?

        (between_group / (n_clusters - 1)) / (within_group / (n_samples - n_clusters))
      end
    end
  end
end
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
require 'rumale/pairwise_metric'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module EvaluationMeasure
    # DaviesBouldinScore is a class that calculates the Davies-Bouldin score.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::DaviesBouldinScore.new
    #   puts evaluator.score(x, predicted)
    #
    # *Reference*
    # - Davies, D L., and Bouldin, D W., "A Cluster Separation Measure," IEEE Trans. Pattern Analysis and Machine Intelligence, Vol. PAMI-1, No. 2, pp. 224--227, 1979.
    class DaviesBouldinScore
      include Base::Evaluator

      # Calculates the Davies-Bouldin score.
      #
      # For every cluster, the largest ratio (s_i + s_j) / d(c_i, c_j) over all other clusters j is taken,
      # where s is the mean distance of members to their centroid and d is the centroid distance.
      # The score is the mean of these maxima.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for calculating score.
      # @param y [Numo::Int32] (shape: [n_samples]) The predicted labels for each sample.
      # @raise [ArgumentError] If y contains fewer than two distinct labels (the score is undefined).
      # @return [Float] The Davies-Bouldin score.
      def score(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)

        labels = y.to_a.uniq.sort
        n_clusters = labels.size
        # With a single cluster the only matrix entry is the diagonal, which is masked out below,
        # so the result would be -Infinity.
        raise ArgumentError, 'Expect to have at least two clusters to calculate the score.' if n_clusters < 2

        n_dimensions = x.shape[1]
        dist_cluster = Numo::DFloat.zeros(n_clusters)
        centroids = Numo::DFloat.zeros(n_clusters, n_dimensions)

        labels.each_with_index do |label, n|
          cls_samples = x[y.eq(label), true]
          cls_centroid = cls_samples.mean(0)
          centroids[n, true] = cls_centroid
          # Mean distance from the cluster members to their centroid.
          dist_cluster[n] = Rumale::PairwiseMetric.euclidean_distance(cls_samples, cls_centroid.expand_dims(0)).mean
        end

        dist_centroid = Rumale::PairwiseMetric.euclidean_distance(centroids)
        # Avoid division by zero for coincident centroids: an infinite distance makes their ratio zero.
        dist_centroid[dist_centroid.eq(0)] = Float::INFINITY
        dist_mat = (dist_cluster.expand_dims(1) + dist_cluster) / dist_centroid
        # Exclude the comparison of each cluster with itself from the maximum.
        dist_mat[dist_mat.diag_indices] = -Float::INFINITY
        dist_mat.max(0).mean
      end
    end
  end
end
|
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # ExplainedVarianceScore is a class that calculates the explained variance score.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::ExplainedVarianceScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class ExplainedVarianceScore
      include Base::Evaluator

      # Calculate explained variance score.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred have different shapes.
      # @return [Float] Explained variance score.
      def score(y_true, y_pred)
        y_true = check_convert_tvalue_array(y_true)
        y_pred = check_convert_tvalue_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        diff = y_true - y_pred
        # Per-output variance of the residuals and of the ground truth.
        numerator = ((diff - diff.mean(0))**2).mean(0)
        denominator = ((y_true - y_true.mean(0))**2).mean(0)

        n_outputs = y_true.shape[1]
        if n_outputs.nil?
          # 1-D targets: a constant ground truth gives an undefined ratio, reported as 0.0 (a Float, as documented).
          denominator.zero? ? 0.0 : 1.0 - numerator / denominator
        else
          # Outputs with a constant ground truth contribute 0 to the average over all outputs.
          valids = denominator.ne(0)
          (1.0 - numerator[valids] / denominator[valids]).sum / n_outputs
        end
      end
    end
  end
end
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
require 'rumale/evaluation_measure/precision_recall'
|
|
5
|
-
|
|
6
|
-
module Rumale
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # FScore is a class that calculates the F1-score of the predicted labels.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::FScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class FScore
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of F1-score.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for F1-score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        check_params_string(average: average)
        @average = average
      end

      # Calculate average F1-score
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @raise [ArgumentError] If the average type is not one of 'binary', 'micro', or 'macro'.
      # @return [Float] Average F1-score
      def score(y_true, y_pred)
        y_true = check_convert_label_array(y_true)
        y_pred = check_convert_label_array(y_pred)

        case @average
        when 'binary'
          # The last per-class score is reported; presumably that of the largest label (positive class).
          f_score_each_class(y_true, y_pred).last
        when 'micro'
          micro_average_f_score(y_true, y_pred)
        when 'macro'
          macro_average_f_score(y_true, y_pred)
        else
          # Previously an unsupported average type silently returned nil.
          raise ArgumentError, "Unsupported average type: #{@average}. Expect 'binary', 'micro', or 'macro'."
        end
      end
    end
  end
end
|
|
@@ -1,147 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
require 'rumale/evaluation_measure/accuracy'
|
|
5
|
-
require 'rumale/evaluation_measure/precision_recall'
|
|
6
|
-
|
|
7
|
-
module Rumale
  module EvaluationMeasure
    module_function

    # Calculate confusion matrix for evaluating classification performance.
    #
    # Rows (ground truth) and columns (prediction) correspond to the distinct labels of y_true
    # in ascending order; predictions of labels that never appear in y_true are not counted.
    #
    # @example
    #   y_true = Numo::Int32[2, 0, 2, 2, 0, 1]
    #   y_pred = Numo::Int32[0, 0, 2, 2, 0, 2]
    #   p Rumale::EvaluationMeasure.confusion_matrix(y_true, y_pred)
    #
    #   # Numo::Int32#shape=[3,3]
    #   # [[2, 0, 0],
    #   #  [0, 0, 1],
    #   #  [1, 0, 2]]
    #
    # @param y_true [Numo::Int32] (shape: [n_samples]) The ground truth labels.
    # @param y_pred [Numo::Int32] (shape: [n_samples]) The predicted labels.
    # @return [Numo::Int32] (shape: [n_classes, n_classes]) The confusion matrix.
    def confusion_matrix(y_true, y_pred)
      y_true = Rumale::Validation.check_convert_label_array(y_true)
      y_pred = Rumale::Validation.check_convert_label_array(y_pred)

      labels = y_true.to_a.uniq.sort
      n_labels = labels.size

      conf_mat = Numo::Int32.zeros(n_labels, n_labels)

      labels.each_with_index do |lbl_a, i|
        y_p = y_pred[y_true.eq(lbl_a)]
        labels.each_with_index do |lbl_b, j|
          conf_mat[i, j] = y_p.eq(lbl_b).count
        end
      end

      conf_mat
    end

    # rubocop:disable Metrics/MethodLength, Metrics/AbcSize

    # Output a summary of classification performance for each class.
    #
    # @example
    #   y_true = Numo::Int32[0, 1, 1, 2, 2, 2, 0]
    #   y_pred = Numo::Int32[1, 1, 1, 0, 0, 2, 0]
    #   puts Rumale::EvaluationMeasure.classification_report(y_true, y_pred)
    #
    #   #               precision    recall  f1-score   support
    #   #
    #   #            0       0.33      0.50      0.40         2
    #   #            1       0.67      1.00      0.80         2
    #   #            2       1.00      0.33      0.50         3
    #   #
    #   #     accuracy                           0.57         7
    #   #    macro avg       0.67      0.61      0.57         7
    #   # weighted avg       0.71      0.57      0.56         7
    #
    # @param y_true [Numo::Int32] (shape: [n_samples]) The ground truth labels.
    # @param y_pred [Numo::Int32] (shape: [n_samples]) The predicted labels.
    # @param target_name [Nil/Array] The label names. The given array is not modified.
    # @param output_hash [Boolean] The flag indicating whether to output with Ruby Hash.
    # @return [String/Hash] The summary of classification performance.
    #   If output_hash is true, it returns the summary with Ruby Hash.
    def classification_report(y_true, y_pred, target_name: nil, output_hash: false)
      y_true = Rumale::Validation.check_convert_label_array(y_true)
      y_pred = Rumale::Validation.check_convert_label_array(y_pred)
      # calculate each evaluation measure.
      classes = y_true.to_a.uniq.sort
      supports = Numo::Int32.asarray(classes.map { |label| y_true.eq(label).count })
      precisions = Rumale::EvaluationMeasure::PrecisionRecall.precision_each_class(y_true, y_pred)
      recalls = Rumale::EvaluationMeasure::PrecisionRecall.recall_each_class(y_true, y_pred)
      fscores = Rumale::EvaluationMeasure::PrecisionRecall.f_score_each_class(y_true, y_pred)
      macro_precision = Rumale::EvaluationMeasure::PrecisionRecall.macro_average_precision(y_true, y_pred)
      macro_recall = Rumale::EvaluationMeasure::PrecisionRecall.macro_average_recall(y_true, y_pred)
      macro_fscore = Rumale::EvaluationMeasure::PrecisionRecall.macro_average_f_score(y_true, y_pred)
      accuracy = Rumale::EvaluationMeasure::Accuracy.new.score(y_true, y_pred)
      sum_supports = supports.sum
      weights = Numo::DFloat.cast(supports) / sum_supports
      weighted_precision = (Numo::DFloat.cast(precisions) * weights).sum
      weighted_recall = (Numo::DFloat.cast(recalls) * weights).sum
      weighted_fscore = (Numo::DFloat.cast(fscores) * weights).sum
      # output results.
      # Build a new array instead of calling map! so the caller's target_name is never mutated
      # (and frozen arrays are accepted).
      target_name = (target_name || classes).map(&:to_s)
      if output_hash
        res = {}
        target_name.each_with_index do |label, n|
          res[label] = {
            precision: precisions[n],
            recall: recalls[n],
            fscore: fscores[n],
            support: supports[n]
          }
        end
        res[:accuracy] = accuracy
        res[:macro_avg] = {
          precision: macro_precision,
          recall: macro_recall,
          fscore: macro_fscore,
          support: sum_supports
        }
        res[:weighted_avg] = {
          precision: weighted_precision,
          recall: weighted_recall,
          fscore: weighted_fscore,
          support: sum_supports
        }
      else
        width = [12, target_name.map(&:size).max].max # 12 is 'weighted avg'.size
        res = +''
        # Right-align each column title in the same 10-character fields used by the rows below.
        res << "#{' ' * width} #{format('%10s%10s%10s%10s', 'precision', 'recall', 'f1-score', 'support')}\n"
        res << "\n"
        target_name.each_with_index do |label, n|
          label_str = format("%##{width}s", label)
          precision_str = format('%#10s', format('%.2f', precisions[n]))
          recall_str = format('%#10s', format('%.2f', recalls[n]))
          fscore_str = format('%#10s', format('%.2f', fscores[n]))
          supports_str = format('%#10s', supports[n])
          res << "#{label_str} #{precision_str}#{recall_str}#{fscore_str}#{supports_str}\n"
        end
        res << "\n"
        supports_str = format('%#10s', sum_supports)
        accuracy_str = format('%#30s', format('%.2f', accuracy))
        res << format("%##{width}s ", 'accuracy')
        res << "#{accuracy_str}#{supports_str}\n"
        precision_str = format('%#10s', format('%.2f', macro_precision))
        recall_str = format('%#10s', format('%.2f', macro_recall))
        fscore_str = format('%#10s', format('%.2f', macro_fscore))
        res << format("%##{width}s ", 'macro avg')
        res << "#{precision_str}#{recall_str}#{fscore_str}#{supports_str}\n"
        precision_str = format('%#10s', format('%.2f', weighted_precision))
        recall_str = format('%#10s', format('%.2f', weighted_recall))
        fscore_str = format('%#10s', format('%.2f', weighted_fscore))
        res << format("%##{width}s ", 'weighted avg')
        res << "#{precision_str}#{recall_str}#{fscore_str}#{supports_str}\n"
      end
      res
    end
    # rubocop:enable Metrics/MethodLength, Metrics/AbcSize
  end
end
|
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
require 'rumale/preprocessing/label_binarizer'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module EvaluationMeasure
    # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::LogLoss.new
    #   puts evaluator.score(ground_truth, predicted)
    class LogLoss
      include Base::Evaluator

      # Calculate mean logarithmic loss.
      #
      # When y_pred is one-dimensional (shape: [n_samples]), its values are taken as probabilities of the
      # positive class and the binary logarithmic loss is returned; the smallest label in y_true is treated
      # as the negative class. Otherwise y_true is one-hot encoded and the multiclass loss is returned.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
      # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
      # @return [Float] mean logarithmic loss
      def score(y_true, y_pred, eps = 1e-15)
        y_true = check_convert_label_array(y_true)
        y_pred = check_convert_tvalue_array(y_pred)

        n_samples, n_classes = y_pred.shape
        # Keep every probability inside [eps, 1 - eps] so the logarithm stays finite.
        probs = y_pred.clip(eps, 1 - eps)

        losses =
          if n_classes.nil?
            neg_label = y_true.to_a.min
            is_positive = Numo::DFloat.cast(y_true.ne(neg_label))
            -(is_positive * Numo::NMath.log(probs) + (1 - is_positive) * Numo::NMath.log(1 - probs))
          else
            one_hot = Numo::DFloat.cast(Rumale::Preprocessing::LabelBinarizer.new.fit_transform(y_true))
            # Re-normalize each row after clipping so the class probabilities sum to one again.
            probs /= probs.sum(1).expand_dims(1)
            -(one_hot * Numo::NMath.log(probs)).sum(1)
          end

        losses.sum / n_samples
      end
    end
  end
end
|
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # MeanAbsoluteError is a class that calculates the mean absolute error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MeanAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanAbsoluteError
      include Base::Evaluator

      # Calculate mean absolute error, averaged over every element (all samples and all outputs).
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred have different shapes.
      # @return [Float] Mean absolute error
      def score(y_true, y_pred)
        truth = check_convert_tvalue_array(y_true)
        estimate = check_convert_tvalue_array(y_pred)
        unless truth.shape == estimate.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        residuals = truth - estimate
        residuals.abs.mean
      end
    end
  end
end
|
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # MeanSquaredError is a class that calculates the mean squared error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MeanSquaredError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredError
      include Base::Evaluator

      # Calculate mean squared error, averaged over every element (all samples and all outputs).
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred have different shapes.
      # @return [Float] Mean squared error
      def score(y_true, y_pred)
        truth = check_convert_tvalue_array(y_true)
        estimate = check_convert_tvalue_array(y_pred)
        unless truth.shape == estimate.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        squared_errors = (truth - estimate)**2
        squared_errors.mean
      end
    end
  end
end
|
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # MeanSquaredLogError is a class that calculates the mean squared logarithmic error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MeanSquaredLogError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredLogError
      include Base::Evaluator

      # Calculate mean squared logarithmic error, i.e. the mean of (log(1 + y_true) - log(1 + y_pred))^2
      # over every element.
      #
      # Target values are expected to be greater than -1; otherwise the logarithm is undefined.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If y_true and y_pred have different shapes.
      # @return [Float] Mean squared logarithmic error.
      def score(y_true, y_pred)
        y_true = check_convert_tvalue_array(y_true)
        y_pred = check_convert_tvalue_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        ((Numo::NMath.log(y_true + 1) - Numo::NMath.log(y_pred + 1))**2).mean
      end
    end
  end
end
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # MedianAbsoluteError is a class that calculates the median absolute error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MedianAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MedianAbsoluteError
      include Base::Evaluator

      # Calculate median absolute error for single-output targets.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples]) Estimated target values.
      # @raise [ArgumentError] If the shapes differ or the targets are not one-dimensional.
      # @return [Float] Median absolute error.
      def score(y_true, y_pred)
        truth = check_convert_tvalue_array(y_true)
        estimate = check_convert_tvalue_array(y_pred)
        unless truth.shape == estimate.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end
        if truth.shape.size > 1 || estimate.shape.size > 1
          raise ArgumentError, 'Expect target values to be 1-D arrray'
        end

        abs_errors = (truth - estimate).abs
        abs_errors.median
      end
    end
  end
end
|
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module EvaluationMeasure
    # MutualInformation is a class that calculates the mutual information.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - Vinh, N X., Epps, J., and Bailey, J., "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class MutualInformation
      include Base::Evaluator

      # Calculate mutual information (in nats, using the natural logarithm).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @raise [ArgumentError] If y_true and y_pred have different sizes.
      # @return [Float] Mutual information.
      def score(y_true, y_pred)
        y_true = check_convert_label_array(y_true)
        y_pred = check_convert_label_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        n_samples = y_pred.size
        # Count class sizes, cluster sizes, and the contingency table in a single pass instead of
        # intersecting index arrays for every (cluster, class) pair.
        class_sizes = Hash.new(0)
        cluster_sizes = Hash.new(0)
        joint_sizes = Hash.new(0)
        y_true.to_a.zip(y_pred.to_a) do |true_label, pred_label|
          class_sizes[true_label] += 1
          cluster_sizes[pred_label] += 1
          joint_sizes[[true_label, pred_label]] += 1
        end

        # Only non-empty intersections are stored, so every term has a positive count.
        joint_sizes.sum(0.0) do |(true_label, pred_label), n_joint|
          n_joint.fdiv(n_samples) *
            Math.log((n_samples * n_joint).fdiv(cluster_sizes[pred_label] * class_sizes[true_label]))
        end
      end
    end
  end
end
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
require 'rumale/evaluation_measure/mutual_information'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module EvaluationMeasure
    # NormalizedMutualInformation is a class that calculates the normalized mutual information.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::NormalizedMutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - Manning, C D., Raghavan, P., and Schutze, H., "Introduction to Information Retrieval," Cambridge University Press., 2008.
    # - Vinh, N X., Epps, J., and Bailey, J., "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class NormalizedMutualInformation
      include Base::Evaluator

      # Calculate normalized mutual information
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @return [Float] Normalized mutual information (0.0 when either labeling has zero entropy)
      def score(y_true, y_pred)
        y_true = check_convert_label_array(y_true)
        y_pred = check_convert_label_array(y_pred)
        # calculate entropies; a zero entropy would make the normalizer zero.
        class_entropy = entropy(y_true)
        return 0.0 if class_entropy.zero?

        cluster_entropy = entropy(y_pred)
        return 0.0 if cluster_entropy.zero?

        # calculate mutual information.
        mi = MutualInformation.new
        mi.score(y_true, y_pred) / Math.sqrt(class_entropy * cluster_entropy)
      end

      private

      # Shannon entropy (natural log) of a label array.
      # Returns 0.0 for an empty array instead of failing on an empty sum.
      def entropy(y)
        labels = y.to_a
        n_samples = labels.size
        # Label frequencies, kept in order of first appearance.
        counts = labels.each_with_object(Hash.new(0)) { |label, memo| memo[label] += 1 }
        # Starting the fold at 0.0 makes an empty input sum to 0.0 and leaves the
        # sequential summation unchanged for non-empty input.
        sum_log = counts.each_value.reduce(0.0) do |acc, count|
          ratio = count.fdiv(n_samples)
          acc + ratio * Math.log(ratio)
        end
        -sum_log
      end
    end
  end
end
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/evaluator'
|
|
4
|
-
require 'rumale/evaluation_measure/precision_recall'
|
|
5
|
-
|
|
6
|
-
module Rumale
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Precision is a class that calculates the precision of the predicted labels.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::Precision.new
    #   puts evaluator.score(ground_truth, predicted)
    class Precision
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of precision.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for precision score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        check_params_string(average: average)
        @average = average
      end

      # Calculate average precision.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Average precision (nil if the average type is not one of the listed values)
      def score(y_true, y_pred)
        labels_true = check_convert_label_array(y_true)
        labels_pred = check_convert_label_array(y_pred)

        if @average == 'binary'
          # Last entry of the per-class precisions (presumably the positive class; see PrecisionRecall).
          precision_each_class(labels_true, labels_pred).last
        elsif @average == 'micro'
          micro_average_precision(labels_true, labels_pred)
        elsif @average == 'macro'
          macro_average_precision(labels_true, labels_pred)
        end
      end
    end
  end
end
|