rumale 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ # This module consists of the classes for model evaluation.
7
+ module EvaluationMeasure
8
+ # Accuracy is a class that calculates the accuracy of classifier from the predicted labels.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::Accuracy.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class Accuracy
14
+ include Base::Evaluator
15
+
16
+ # Calculate mean accuracy.
17
+ #
18
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
19
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
20
+ # @return [Float] Mean accuracy
21
def score(y_true, y_pred)
  check_label_array(y_true)
  check_label_array(y_pred)

  # Count matching positions directly instead of mapping to 1/0 and folding:
  # `inject(:+)` yields nil for empty input, which made the division raise
  # NoMethodError. With `count` + `fdiv` an empty input gives NaN instead.
  n_correct = y_true.to_a.each_with_index.count { |label, n| label == y_pred[n] }
  n_correct.fdiv(y_true.size)
end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ # This module consists of the classes for model evaluation.
8
+ module EvaluationMeasure
9
+ # FScore is a class that calculates the F1-score of the predicted labels.
10
+ #
11
+ # @example
12
+ # evaluator = Rumale::EvaluationMeasure::FScore.new
13
+ # puts evaluator.score(ground_truth, predicted)
14
+ class FScore
15
+ include Base::Evaluator
16
+ include EvaluationMeasure::PrecisionRecall
17
+
18
+ # Return the average type for calculation of F1-score.
19
+ # @return [String] ('binary', 'micro', 'macro')
20
+ attr_reader :average
21
+
22
+ # Create a new evaluation measure calculator for F1-score.
23
+ #
24
+ # @param average [String] The average type ('binary', 'micro', 'macro')
25
def initialize(average: 'binary')
  # Run the string check on the keyword before it is stored
  # (check_params_string is defined outside this block).
  check_params_string(average: average)

  @average = average
end
29
+
30
+ # Calculate average F1-score
31
+ #
32
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
33
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
34
+ # @return [Float] Average F1-score
35
def score(y_true, y_pred)
  check_label_array(y_true)
  check_label_array(y_pred)

  case @average
  when 'binary'
    # Only the last entry of the per-class scores is reported.
    f_score_each_class(y_true, y_pred).last
  when 'micro'
    micro_average_f_score(y_true, y_pred)
  when 'macro'
    macro_average_f_score(y_true, y_pred)
  else
    # Previously an unsupported average fell through the case and the method
    # silently returned nil; fail loudly with the offending value instead.
    raise ArgumentError, "Expect average to be 'binary', 'micro', or 'macro', but got #{@average.inspect}."
  end
end
48
+ end
49
+ end
50
+ end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/preprocessing/one_hot_encoder'
5
+
6
+ module Rumale
7
+ module EvaluationMeasure
8
+ # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::LogLoss.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class LogLoss
14
+ include Base::Evaluator
15
+
16
+ # Calculate mean logarithmic loss.
17
+ # If both y_true and y_pred are array (both shapes are [n_samples]), this method calculates
18
+ # mean logarithmic loss for binary classification.
19
+ #
20
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
21
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
22
+ # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
23
+ # @return [Float] mean logarithmic loss
24
def score(y_true, y_pred, eps = 1e-15)
  check_params_type(Numo::Int32, y_true: y_true)
  check_params_type(Numo::DFloat, y_pred: y_pred)

  n_samples, n_classes = y_pred.shape
  # Keep probabilities inside [eps, 1 - eps] so the logarithms stay finite.
  probs = y_pred.clip(eps, 1 - eps)

  if n_classes.nil?
    # y_pred has a single axis: binary case. Every label other than the
    # smallest one is treated as the positive class.
    negative_label = y_true.to_a.uniq.min
    positives = Numo::DFloat.cast(y_true.ne(negative_label))
    per_sample = -(positives * Numo::NMath.log(probs) + (1 - positives) * Numo::NMath.log(1 - probs))
  else
    # y_pred has a class axis: one-hot encode the labels and renormalize each
    # row of the clipped probabilities before taking the cross entropy.
    one_hot = Rumale::Preprocessing::OneHotEncoder.new.fit_transform(y_true)
    probs /= probs.sum(1).expand_dims(1)
    per_sample = -(one_hot * Numo::NMath.log(probs)).sum(1)
  end

  per_sample.sum / n_samples
end
43
+ end
44
+ end
45
+ end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # MeanAbsoluteError is a class that calculates the mean absolute error.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::MeanAbsoluteError.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ class MeanAbsoluteError
13
+ include Base::Evaluator
14
+
15
+ # Calculate mean absolute error.
16
+ #
17
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
18
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
19
+ # @return [Float] Mean absolute error
20
def score(y_true, y_pred)
  [y_true, y_pred].each { |values| check_tvalue_array(values) }
  unless y_true.shape == y_pred.shape
    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
  end

  # Element-wise residuals, averaged over every element.
  residuals = y_true - y_pred
  residuals.abs.mean
end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # MeanSquaredError is a class that calculates the mean squared error.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::MeanSquaredError.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ class MeanSquaredError
13
+ include Base::Evaluator
14
+
15
+ # Calculate mean squared error.
16
+ #
17
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
18
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
19
+ # @return [Float] Mean squared error
20
def score(y_true, y_pred)
  [y_true, y_pred].each { |values| check_tvalue_array(values) }
  unless y_true.shape == y_pred.shape
    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
  end

  # Element-wise squared residuals, averaged over every element.
  squared_errors = (y_true - y_pred)**2
  squared_errors.mean
end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,62 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::NormalizedMutualInformation.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ #
13
+ # *Reference*
14
+ # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
15
+ # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
16
+ class NormalizedMutualInformation
17
+ include Base::Evaluator
18
+
19
+ # Calculate normalized mutual information
20
+ #
21
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
22
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
23
+ # @return [Float] Normalized mutual information
24
def score(y_true, y_pred)
  check_label_array(y_true)
  check_label_array(y_pred)

  n_samples = y_pred.size
  class_ids = y_true.to_a.uniq
  cluster_ids = y_pred.to_a.uniq

  # Entropy of a labelling: -sum(ratio * log(ratio)) over its distinct ids.
  # Folding from 0.0 keeps the result a Float even when there are no ids.
  entropy_of = lambda do |labels, ids|
    terms = ids.map do |id|
      ratio = labels.eq(id).count.fdiv(n_samples)
      ratio * Math.log(ratio)
    end
    -1.0 * terms.reduce(0.0, :+)
  end

  # The normalizer below is sqrt(class_entropy * cluster_entropy); if either
  # entropy is zero the score is reported as 0.0 instead of dividing by zero.
  class_entropy = entropy_of.call(y_true, class_ids)
  return 0.0 if class_entropy.zero?

  cluster_entropy = entropy_of.call(y_pred, cluster_ids)
  return 0.0 if cluster_entropy.zero?

  # The sample positions of each class do not depend on the cluster being
  # visited, so collect them once instead of once per (cluster, class) pair.
  class_members = class_ids.map { |j| y_true.eq(j).where.to_a }

  mutual_information = 0.0
  cluster_ids.each do |k|
    cluster_members = y_pred.eq(k).where.to_a
    class_members.each do |members|
      n_shared = (cluster_members & members).size
      # Empty intersections contribute nothing (and would hit log(0)).
      next unless n_shared.positive?

      mutual_information +=
        n_shared.fdiv(n_samples) * Math.log((n_samples * n_shared).fdiv(cluster_members.size * members.size))
    end
  end

  mutual_information / Math.sqrt(class_entropy * cluster_entropy)
end
60
+ end
61
+ end
62
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ # This module consists of the classes for model evaluation.
8
+ module EvaluationMeasure
9
+ # Precision is a class that calculates the precision of the predicted labels.
10
+ #
11
+ # @example
12
+ # evaluator = Rumale::EvaluationMeasure::Precision.new
13
+ # puts evaluator.score(ground_truth, predicted)
14
+ class Precision
15
+ include Base::Evaluator
16
+ include EvaluationMeasure::PrecisionRecall
17
+
18
+ # Return the average type for calculation of precision.
19
+ # @return [String] ('binary', 'micro', 'macro')
20
+ attr_reader :average
21
+
22
+ # Create a new evaluation measure calculator for precision score.
23
+ #
24
+ # @param average [String] The average type ('binary', 'micro', 'macro')
25
def initialize(average: 'binary')
  # Run the string check on the keyword before it is stored
  # (check_params_string is defined outside this block).
  check_params_string(average: average)

  @average = average
end
29
+
30
+ # Calculate average precision.
31
+ #
32
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
33
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
34
+ # @return [Float] Average precision
35
def score(y_true, y_pred)
  check_label_array(y_true)
  check_label_array(y_pred)

  case @average
  when 'binary'
    # Only the last entry of the per-class scores is reported.
    precision_each_class(y_true, y_pred).last
  when 'micro'
    micro_average_precision(y_true, y_pred)
  when 'macro'
    macro_average_precision(y_true, y_pred)
  else
    # Previously an unsupported average fell through the case and the method
    # silently returned nil; fail loudly with the offending value instead.
    raise ArgumentError, "Expect average to be 'binary', 'micro', or 'macro', but got #{@average.inspect}."
  end
end
48
+ end
49
+ end
50
+ end
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ # This module consists of the classes for model evaluation.
7
+ module EvaluationMeasure
8
+ # @!visibility private
9
module PrecisionRecall
  module_function

  # @!visibility private
  # Precision for every distinct label of y_true, in ascending label order.
  # A label that is never predicted gets 0.0.
  def precision_each_class(y_true, y_pred)
    y_true.sort.to_a.uniq.map do |label|
      target_positions = y_pred.eq(label)
      next 0.0 if y_pred[target_positions].empty?

      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
      n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
      n_true_positives / (n_true_positives + n_false_positives)
    end
  end

  # @!visibility private
  # Recall for every distinct label of y_true, in ascending label order.
  def recall_each_class(y_true, y_pred)
    y_true.sort.to_a.uniq.map do |label|
      target_positions = y_true.eq(label)
      next 0.0 if y_pred[target_positions].empty?

      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
      n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
      n_true_positives / (n_true_positives + n_false_negatives)
    end
  end

  # @!visibility private
  # Harmonic mean of precision and recall per label; 0.0 when both are zero.
  def f_score_each_class(y_true, y_pred)
    precision_each_class(y_true, y_pred).zip(recall_each_class(y_true, y_pred)).map do |p, r|
      next 0.0 if p.zero? && r.zero?

      (2.0 * p * r) / (p + r)
    end
  end

  # @!visibility private
  # Precision computed from true-positive and predicted counts pooled over all labels of y_true.
  def micro_average_precision(y_true, y_pred)
    evaluated_values = y_true.sort.to_a.uniq.map do |label|
      target_positions = y_pred.eq(label)
      next [0.0, 0.0] if y_pred[target_positions].empty?

      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
      n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
      [n_true_positives, n_true_positives + n_false_positives]
    end
    n_hits, n_predicted = evaluated_values.transpose.map { |v| v.inject(:+) }
    # When no prediction uses any label of y_true the pooled denominator is
    # zero; report 0.0 (as precision_each_class does) instead of NaN.
    n_predicted.zero? ? 0.0 : n_hits / n_predicted
  end

  # @!visibility private
  # Recall computed from true-positive and actual counts pooled over all labels of y_true.
  def micro_average_recall(y_true, y_pred)
    evaluated_values = y_true.sort.to_a.uniq.map do |label|
      target_positions = y_true.eq(label)
      # Must be a [true positives, total] pair like the regular result below:
      # a bare 0.0 here would make the transpose that follows raise.
      next [0.0, 0.0] if y_pred[target_positions].empty?

      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
      n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
      [n_true_positives, n_true_positives + n_false_negatives]
    end
    res = evaluated_values.transpose.map { |v| v.inject(:+) }
    res.first / res.last
  end

  # @!visibility private
  # Harmonic mean of the micro-averaged precision and recall.
  def micro_average_f_score(y_true, y_pred)
    precision = micro_average_precision(y_true, y_pred)
    recall = micro_average_recall(y_true, y_pred)
    # Same guard as f_score_each_class: without it, a result with no correct
    # prediction evaluates 0.0 / 0.0 and returns NaN.
    return 0.0 if precision.zero? && recall.zero?

    (2.0 * precision * recall) / (precision + recall)
  end

  # @!visibility private
  # Unweighted mean of the per-label precision.
  def macro_average_precision(y_true, y_pred)
    precision_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
  end

  # @!visibility private
  # Unweighted mean of the per-label recall.
  def macro_average_recall(y_true, y_pred)
    recall_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
  end

  # @!visibility private
  # Unweighted mean of the per-label F-score.
  def macro_average_f_score(y_true, y_pred)
    f_score_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
  end
end
90
+ end
91
+ end
@@ -0,0 +1,40 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # Purity is a class that calculates the purity of clustering results.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::Purity.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ #
13
+ # *Reference*
14
+ # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
15
+ class Purity
16
+ include Base::Evaluator
17
+
18
+ # Calculate purity
19
+ #
20
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
21
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
22
+ # @return [Float] Purity
23
def score(y_true, y_pred)
  check_label_array(y_true)
  check_label_array(y_pred)

  # The sample positions of each class do not depend on the cluster being
  # visited, so collect them once instead of once per cluster.
  class_members = y_true.to_a.uniq.map { |j| y_true.eq(j).where.to_a }

  # For every cluster, take the size of its largest overlap with any class.
  n_majority = y_pred.to_a.uniq.sum do |k|
    cluster_members = y_pred.eq(k).where.to_a
    class_members.map { |members| (cluster_members & members).size }.max
  end

  n_majority.fdiv(y_pred.size)
end
38
+ end
39
+ end
40
+ end
@@ -0,0 +1,43 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ module EvaluationMeasure
8
+ # R2Score is a class that calculates the coefficient of determination for the predicted values.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::R2Score.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class R2Score
14
+ include Base::Evaluator
15
+
16
+ # Create a new evaluation measure calculator for coefficient of determination.
17
def initialize
  # Nothing to configure: no arguments are accepted and no state is set.
end
18
+
19
+ # Calculate the coefficient of determination.
20
+ #
21
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
22
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
23
+ # @return [Float] Coefficient of determination
24
def score(y_true, y_pred)
  [y_true, y_pred].each { |values| check_tvalue_array(values) }
  unless y_true.shape == y_pred.shape
    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
  end

  n_samples, n_outputs = y_true.shape
  # Residual sum of squares and total sum of squares, reduced along axis 0.
  residual_ss = ((y_true - y_pred)**2).sum(0)
  target_mean = y_true.sum(0) / n_samples
  total_ss = ((y_true - target_mean)**2).sum(0)

  if n_outputs.nil?
    # Single-axis targets: a zero total sum of squares scores 0.0.
    return 0.0 if total_ss.zero?

    return 1.0 - residual_ss / total_ss
  end

  # Targets with an output axis: score each output, overwrite those with a
  # zero total sum of squares with 0.0, then average.
  per_output = 1 - residual_ss / total_ss
  per_output[total_ss.eq(0)] = 0.0
  per_output.sum / per_output.size
end
41
+ end
42
+ end
43
+ end