rumale 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ # This module consists of the classes for model evaluation.
7
+ module EvaluationMeasure
8
+ # Accuracy is a class that calculates the accuracy of a classifier from the predicted labels.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::Accuracy.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class Accuracy
14
+ include Base::Evaluator
15
+
16
+ # Calculate mean accuracy.
17
+ #
18
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
19
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
20
+ # @return [Float] Mean accuracy
21
+ def score(y_true, y_pred)
22
+ check_label_array(y_true)
23
+ check_label_array(y_pred)
24
+
25
+ (y_true.to_a.map.with_index { |label, n| label == y_pred[n] ? 1 : 0 }).inject(:+) / y_true.size.to_f
26
+ end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ # This module consists of the classes for model evaluation.
8
+ module EvaluationMeasure
9
+ # FScore is a class that calculates the F1-score of the predicted labels.
10
+ #
11
+ # @example
12
+ # evaluator = Rumale::EvaluationMeasure::FScore.new
13
+ # puts evaluator.score(ground_truth, predicted)
14
+ class FScore
15
+ include Base::Evaluator
16
+ include EvaluationMeasure::PrecisionRecall
17
+
18
+ # Return the average type for calculation of F1-score.
19
+ # @return [String] ('binary', 'micro', 'macro')
20
+ attr_reader :average
21
+
22
+ # Create a new evaluation measure calculator for F1-score.
23
+ #
24
+ # @param average [String] The average type ('binary', 'micro', 'macro')
25
+ def initialize(average: 'binary')
26
+ check_params_string(average: average)
27
+ @average = average
28
+ end
29
+
30
+ # Calculate average F1-score
31
+ #
32
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
33
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
34
+ # @return [Float] Average F1-score
35
+ def score(y_true, y_pred)
36
+ check_label_array(y_true)
37
+ check_label_array(y_pred)
38
+
39
+ case @average
40
+ when 'binary'
41
+ f_score_each_class(y_true, y_pred).last
42
+ when 'micro'
43
+ micro_average_f_score(y_true, y_pred)
44
+ when 'macro'
45
+ macro_average_f_score(y_true, y_pred)
46
+ end
47
+ end
48
+ end
49
+ end
50
+ end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/preprocessing/one_hot_encoder'
5
+
6
+ module Rumale
7
+ module EvaluationMeasure
8
+ # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::LogLoss.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class LogLoss
14
+ include Base::Evaluator
15
+
16
+ # Calculate mean logarithmic loss.
17
+ # If both y_true and y_pred are one-dimensional arrays (both shapes are [n_samples]), this method calculates
18
+ # mean logarithmic loss for binary classification.
19
+ #
20
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
21
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
22
+ # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
23
+ # @return [Float] mean logarithmic loss
24
+ def score(y_true, y_pred, eps = 1e-15)
25
+ check_params_type(Numo::Int32, y_true: y_true)
26
+ check_params_type(Numo::DFloat, y_pred: y_pred)
27
+
28
+ n_samples, n_classes = y_pred.shape
29
+ clipped_p = y_pred.clip(eps, 1 - eps)
30
+
31
+ log_loss = if n_classes.nil?
32
+ negative_label = y_true.to_a.uniq.min
33
+ bin_y_true = Numo::DFloat.cast(y_true.ne(negative_label))
34
+ -(bin_y_true * Numo::NMath.log(clipped_p) + (1 - bin_y_true) * Numo::NMath.log(1 - clipped_p))
35
+ else
36
+ encoder = Rumale::Preprocessing::OneHotEncoder.new
37
+ encoded_y_true = encoder.fit_transform(y_true)
38
+ clipped_p /= clipped_p.sum(1).expand_dims(1)
39
+ -(encoded_y_true * Numo::NMath.log(clipped_p)).sum(1)
40
+ end
41
+ log_loss.sum / n_samples
42
+ end
43
+ end
44
+ end
45
+ end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # MeanAbsoluteError is a class that calculates the mean absolute error.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::MeanAbsoluteError.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ class MeanAbsoluteError
13
+ include Base::Evaluator
14
+
15
+ # Calculate mean absolute error.
16
+ #
17
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
18
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
19
+ # @return [Float] Mean absolute error
20
+ def score(y_true, y_pred)
21
+ check_tvalue_array(y_true)
22
+ check_tvalue_array(y_pred)
23
+ raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape
24
+
25
+ (y_true - y_pred).abs.mean
26
+ end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # MeanSquaredError is a class that calculates the mean squared error.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::MeanSquaredError.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ class MeanSquaredError
13
+ include Base::Evaluator
14
+
15
+ # Calculate mean squared error.
16
+ #
17
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
18
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
19
+ # @return [Float] Mean squared error
20
+ def score(y_true, y_pred)
21
+ check_tvalue_array(y_true)
22
+ check_tvalue_array(y_pred)
23
+ raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape
24
+
25
+ ((y_true - y_pred)**2).mean
26
+ end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,62 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::NormalizedMutualInformation.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ #
13
+ # *Reference*
14
+ # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
15
+ # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
16
+ class NormalizedMutualInformation
17
+ include Base::Evaluator
18
+
19
+ # Calculate normalized mutual information
20
+ #
21
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
22
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
23
+ # @return [Float] Normalized mutual information
24
+ def score(y_true, y_pred)
25
+ check_label_array(y_true)
26
+ check_label_array(y_pred)
27
+ # initialize some variables.
28
+ mutual_information = 0.0
29
+ n_samples = y_pred.size
30
+ class_ids = y_true.to_a.uniq
31
+ cluster_ids = y_pred.to_a.uniq
32
+ # calculate entropy.
33
+ class_entropy = -1.0 * class_ids.map do |k|
34
+ ratio = y_true.eq(k).count.fdiv(n_samples)
35
+ ratio * Math.log(ratio)
36
+ end.reduce(:+)
37
+ return 0.0 if class_entropy.zero?
38
+ cluster_entropy = -1.0 * cluster_ids.map do |k|
39
+ ratio = y_pred.eq(k).count.fdiv(n_samples)
40
+ ratio * Math.log(ratio)
41
+ end.reduce(:+)
42
+ return 0.0 if cluster_entropy.zero?
43
+ # calculate mutual information.
44
+ cluster_ids.map do |k|
45
+ pr_sample_ids = y_pred.eq(k).where.to_a
46
+ n_pr_samples = pr_sample_ids.size
47
+ class_ids.map do |j|
48
+ tr_sample_ids = y_true.eq(j).where.to_a
49
+ n_tr_samples = tr_sample_ids.size
50
+ n_intr_samples = (pr_sample_ids & tr_sample_ids).size
51
+ if n_intr_samples.positive?
52
+ mutual_information +=
53
+ n_intr_samples.fdiv(n_samples) * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
54
+ end
55
+ end
56
+ end
57
+ # return normalized mutual information.
58
+ mutual_information / Math.sqrt(class_entropy * cluster_entropy)
59
+ end
60
+ end
61
+ end
62
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ # This module consists of the classes for model evaluation.
8
+ module EvaluationMeasure
9
+ # Precision is a class that calculates the precision of the predicted labels.
10
+ #
11
+ # @example
12
+ # evaluator = Rumale::EvaluationMeasure::Precision.new
13
+ # puts evaluator.score(ground_truth, predicted)
14
+ class Precision
15
+ include Base::Evaluator
16
+ include EvaluationMeasure::PrecisionRecall
17
+
18
+ # Return the average type for calculation of precision.
19
+ # @return [String] ('binary', 'micro', 'macro')
20
+ attr_reader :average
21
+
22
+ # Create a new evaluation measure calculator for precision score.
23
+ #
24
+ # @param average [String] The average type ('binary', 'micro', 'macro')
25
+ def initialize(average: 'binary')
26
+ check_params_string(average: average)
27
+ @average = average
28
+ end
29
+
30
+ # Calculate average precision.
31
+ #
32
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
33
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
34
+ # @return [Float] Average precision
35
+ def score(y_true, y_pred)
36
+ check_label_array(y_true)
37
+ check_label_array(y_pred)
38
+
39
+ case @average
40
+ when 'binary'
41
+ precision_each_class(y_true, y_pred).last
42
+ when 'micro'
43
+ micro_average_precision(y_true, y_pred)
44
+ when 'macro'
45
+ macro_average_precision(y_true, y_pred)
46
+ end
47
+ end
48
+ end
49
+ end
50
+ end
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ # This module consists of the classes for model evaluation.
7
+ module EvaluationMeasure
8
+ # @!visibility private
9
+ module PrecisionRecall
10
+ module_function
11
+
12
+ # @!visibility private
13
+ def precision_each_class(y_true, y_pred)
14
+ y_true.sort.to_a.uniq.map do |label|
15
+ target_positions = y_pred.eq(label)
16
+ next 0.0 if y_pred[target_positions].empty?
17
+ n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
18
+ n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
19
+ n_true_positives / (n_true_positives + n_false_positives)
20
+ end
21
+ end
22
+
23
+ # @!visibility private
24
+ def recall_each_class(y_true, y_pred)
25
+ y_true.sort.to_a.uniq.map do |label|
26
+ target_positions = y_true.eq(label)
27
+ next 0.0 if y_pred[target_positions].empty?
28
+ n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
29
+ n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
30
+ n_true_positives / (n_true_positives + n_false_negatives)
31
+ end
32
+ end
33
+
34
+ # @!visibility private
35
+ def f_score_each_class(y_true, y_pred)
36
+ precision_each_class(y_true, y_pred).zip(recall_each_class(y_true, y_pred)).map do |p, r|
37
+ next 0.0 if p.zero? && r.zero?
38
+ (2.0 * p * r) / (p + r)
39
+ end
40
+ end
41
+
42
+ # @!visibility private
43
+ def micro_average_precision(y_true, y_pred)
44
+ evaluated_values = y_true.sort.to_a.uniq.map do |label|
45
+ target_positions = y_pred.eq(label)
46
+ next [0.0, 0.0] if y_pred[target_positions].empty?
47
+ n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
48
+ n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
49
+ [n_true_positives, n_true_positives + n_false_positives]
50
+ end
51
+ res = evaluated_values.transpose.map { |v| v.inject(:+) }
52
+ res.first / res.last
53
+ end
54
+
55
+ # @!visibility private
56
+ def micro_average_recall(y_true, y_pred)
57
+ evaluated_values = y_true.sort.to_a.uniq.map do |label|
58
+ target_positions = y_true.eq(label)
59
+ next 0.0 if y_pred[target_positions].empty?
60
+ n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
61
+ n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
62
+ [n_true_positives, n_true_positives + n_false_negatives]
63
+ end
64
+ res = evaluated_values.transpose.map { |v| v.inject(:+) }
65
+ res.first / res.last
66
+ end
67
+
68
+ # @!visibility private
69
+ def micro_average_f_score(y_true, y_pred)
70
+ p = micro_average_precision(y_true, y_pred)
71
+ r = micro_average_recall(y_true, y_pred)
72
+ (2.0 * p * r) / (p + r)
73
+ end
74
+
75
+ # @!visibility private
76
+ def macro_average_precision(y_true, y_pred)
77
+ precision_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
78
+ end
79
+
80
+ # @!visibility private
81
+ def macro_average_recall(y_true, y_pred)
82
+ recall_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
83
+ end
84
+
85
+ # @!visibility private
86
+ def macro_average_f_score(y_true, y_pred)
87
+ f_score_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
88
+ end
89
+ end
90
+ end
91
+ end
@@ -0,0 +1,40 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+
5
+ module Rumale
6
+ module EvaluationMeasure
7
+ # Purity is a class that calculates the purity of clustering results.
8
+ #
9
+ # @example
10
+ # evaluator = Rumale::EvaluationMeasure::Purity.new
11
+ # puts evaluator.score(ground_truth, predicted)
12
+ #
13
+ # *Reference*
14
+ # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
15
+ class Purity
16
+ include Base::Evaluator
17
+
18
+ # Calculate purity
19
+ #
20
+ # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
21
+ # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
22
+ # @return [Float] Purity
23
+ def score(y_true, y_pred)
24
+ check_label_array(y_true)
25
+ check_label_array(y_pred)
26
+ # initialize some variables.
27
+ purity = 0
28
+ n_samples = y_pred.size
29
+ class_ids = y_true.to_a.uniq
30
+ cluster_ids = y_pred.to_a.uniq
31
+ # calculate purity.
32
+ cluster_ids.each do |k|
33
+ pr_sample_ids = y_pred.eq(k).where.to_a
34
+ purity += class_ids.map { |j| (pr_sample_ids & y_true.eq(j).where.to_a).size }.max
35
+ end
36
+ purity.fdiv(n_samples)
37
+ end
38
+ end
39
+ end
40
+ end
@@ -0,0 +1,43 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/evaluator'
4
+ require 'rumale/evaluation_measure/precision_recall'
5
+
6
+ module Rumale
7
+ module EvaluationMeasure
8
+ # R2Score is a class that calculates the coefficient of determination for the predicted values.
9
+ #
10
+ # @example
11
+ # evaluator = Rumale::EvaluationMeasure::R2Score.new
12
+ # puts evaluator.score(ground_truth, predicted)
13
+ class R2Score
14
+ include Base::Evaluator
15
+
16
+ # Create a new evaluation measure calculator for coefficient of determination.
17
+ def initialize; end
18
+
19
+ # Calculate the coefficient of determination.
20
+ #
21
+ # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
22
+ # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
23
+ # @return [Float] Coefficient of determination
24
+ def score(y_true, y_pred)
25
+ check_tvalue_array(y_true)
26
+ check_tvalue_array(y_pred)
27
+ raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape
28
+
29
+ n_samples, n_outputs = y_true.shape
30
+ numerator = ((y_true - y_pred)**2).sum(0)
31
+ yt_mean = y_true.sum(0) / n_samples
32
+ denominator = ((y_true - yt_mean)**2).sum(0)
33
+ if n_outputs.nil?
34
+ denominator.zero? ? 0.0 : 1.0 - numerator / denominator
35
+ else
36
+ scores = 1 - numerator / denominator
37
+ scores[denominator.eq(0)] = 0.0
38
+ scores.sum / scores.size
39
+ end
40
+ end
41
+ end
42
+ end
43
+ end