rumale 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
 - data/.coveralls.yml +1 -0
 - data/.gitignore +20 -0
 - data/.rspec +3 -0
 - data/.rubocop.yml +47 -0
 - data/.rubocop_todo.yml +58 -0
 - data/.travis.yml +13 -0
 - data/CHANGELOG.md +2 -0
 - data/CODE_OF_CONDUCT.md +74 -0
 - data/Gemfile +4 -0
 - data/LICENSE.txt +23 -0
 - data/README.md +175 -0
 - data/Rakefile +6 -0
 - data/bin/console +14 -0
 - data/bin/setup +8 -0
 - data/lib/rumale.rb +70 -0
 - data/lib/rumale/base/base_estimator.rb +13 -0
 - data/lib/rumale/base/classifier.rb +36 -0
 - data/lib/rumale/base/cluster_analyzer.rb +31 -0
 - data/lib/rumale/base/evaluator.rb +17 -0
 - data/lib/rumale/base/regressor.rb +36 -0
 - data/lib/rumale/base/splitter.rb +21 -0
 - data/lib/rumale/base/transformer.rb +22 -0
 - data/lib/rumale/clustering/dbscan.rb +125 -0
 - data/lib/rumale/clustering/k_means.rb +138 -0
 - data/lib/rumale/dataset.rb +110 -0
 - data/lib/rumale/decomposition/nmf.rb +141 -0
 - data/lib/rumale/decomposition/pca.rb +148 -0
 - data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
 - data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
 - data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
 - data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
 - data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
 - data/lib/rumale/evaluation_measure/f_score.rb +50 -0
 - data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
 - data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
 - data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
 - data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
 - data/lib/rumale/evaluation_measure/precision.rb +50 -0
 - data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
 - data/lib/rumale/evaluation_measure/purity.rb +40 -0
 - data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
 - data/lib/rumale/evaluation_measure/recall.rb +50 -0
 - data/lib/rumale/kernel_approximation/rbf.rb +121 -0
 - data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
 - data/lib/rumale/linear_model/base_linear_model.rb +89 -0
 - data/lib/rumale/linear_model/lasso.rb +136 -0
 - data/lib/rumale/linear_model/linear_regression.rb +110 -0
 - data/lib/rumale/linear_model/logistic_regression.rb +159 -0
 - data/lib/rumale/linear_model/ridge.rb +110 -0
 - data/lib/rumale/linear_model/svc.rb +183 -0
 - data/lib/rumale/linear_model/svr.rb +122 -0
 - data/lib/rumale/model_selection/cross_validation.rb +123 -0
 - data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
 - data/lib/rumale/model_selection/k_fold.rb +76 -0
 - data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
 - data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
 - data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
 - data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
 - data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
 - data/lib/rumale/optimizer/nadam.rb +90 -0
 - data/lib/rumale/optimizer/rmsprop.rb +69 -0
 - data/lib/rumale/optimizer/sgd.rb +65 -0
 - data/lib/rumale/optimizer/yellow_fin.rb +144 -0
 - data/lib/rumale/pairwise_metric.rb +91 -0
 - data/lib/rumale/pipeline/pipeline.rb +197 -0
 - data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
 - data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
 - data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
 - data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
 - data/lib/rumale/preprocessing/label_encoder.rb +94 -0
 - data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
 - data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
 - data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
 - data/lib/rumale/probabilistic_output.rb +112 -0
 - data/lib/rumale/tree/base_decision_tree.rb +153 -0
 - data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
 - data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
 - data/lib/rumale/tree/node.rb +70 -0
 - data/lib/rumale/utils.rb +37 -0
 - data/lib/rumale/validation.rb +79 -0
 - data/lib/rumale/values.rb +13 -0
 - data/lib/rumale/version.rb +6 -0
 - data/rumale.gemspec +41 -0
 - metadata +204 -0
 
| 
         @@ -0,0 +1,29 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # Accuracy is a class that calculates the accuracy of classifier from the predicted labels.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::Accuracy.new
                #   puts evaluator.score(ground_truth, predicted)
                class Accuracy
                  include Base::Evaluator

                  # Calculate mean accuracy, i.e. the fraction of samples whose predicted label
                  # equals the ground truth label at the same position.
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
                  # @raise [ArgumentError] If y_true and y_pred do not have the same size.
                  # @return [Float] Mean accuracy
                  def score(y_true, y_pred)
                    check_label_array(y_true)
                    check_label_array(y_pred)
                    # Refuse to compare label arrays of different lengths instead of comparing misaligned positions.
                    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.size == y_pred.size

                    # Vectorized element-wise comparison instead of a Ruby-level loop over every sample.
                    y_true.eq(y_pred).count.fdiv(y_true.size)
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,50 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
            require 'rumale/evaluation_measure/precision_recall'
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # FScore is a class that calculates the F1-score of the predicted labels.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::FScore.new
                #   puts evaluator.score(ground_truth, predicted)
                class FScore
                  include Base::Evaluator
                  include EvaluationMeasure::PrecisionRecall

                  # Return the average type for calculation of F1-score.
                  # @return [String] ('binary', 'micro', 'macro')
                  attr_reader :average

                  # Create a new evaluation measure calculator for F1-score.
                  #
                  # @param average [String] The average type ('binary', 'micro', 'macro')
                  def initialize(average: 'binary')
                    check_params_string(average: average)
                    @average = average
                  end

                  # Calculate average F1-score
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
                  # @raise [ArgumentError] If the average type is not one of 'binary', 'micro', or 'macro'.
                  # @return [Float] Average F1-score
                  def score(y_true, y_pred)
                    check_label_array(y_true)
                    check_label_array(y_pred)

                    case @average
                    when 'binary'
                      # NOTE(review): uses the last per-class score, presumably that of the positive class —
                      # confirm against PrecisionRecall#f_score_each_class.
                      f_score_each_class(y_true, y_pred).last
                    when 'micro'
                      micro_average_f_score(y_true, y_pred)
                    when 'macro'
                      macro_average_f_score(y_true, y_pred)
                    else
                      # Without this branch an unsupported average type would make score silently return nil.
                      raise ArgumentError, "Invalid average type: #{@average}. Expect 'binary', 'micro', or 'macro'."
                    end
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,45 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
            require 'rumale/preprocessing/one_hot_encoder'
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::LogLoss.new
                #   puts evaluator.score(ground_truth, predicted)
                class LogLoss
                  include Base::Evaluator

                  # Calculate mean logarithmic loss.
                  # If both y_true and y_pred are array (both shapes are [n_samples]), this method calculates
                  # mean logarithmic loss for binary classification.
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
                  # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
                  # @raise [ArgumentError] If y_true and y_pred do not have the same number of samples.
                  # @return [Float] mean logarithmic loss
                  def score(y_true, y_pred, eps = 1e-15)
                    check_params_type(Numo::Int32, y_true: y_true)
                    check_params_type(Numo::DFloat, y_pred: y_pred)

                    n_samples, n_classes = y_pred.shape
                    # Fail early with a clear message instead of relying on the array arithmetic below to notice a mismatch.
                    raise ArgumentError, 'Expect to have the same number of samples both y_true and y_pred.' unless y_true.shape[0] == n_samples

                    # Clip into [eps, 1 - eps] so that the logarithm never receives 0.
                    clipped_p = y_pred.clip(eps, 1 - eps)

                    log_loss = if n_classes.nil?
                                 # Binary case: samples carrying the smallest label are negative, every other label is positive,
                                 # and y_pred is read as the probability of the positive side.
                                 # NOTE(review): if y_true holds a single label, that label is treated as negative — confirm intended.
                                 negative_label = y_true.to_a.uniq.min
                                 bin_y_true = Numo::DFloat.cast(y_true.ne(negative_label))
                                 -(bin_y_true * Numo::NMath.log(clipped_p) + (1 - bin_y_true) * Numo::NMath.log(1 - clipped_p))
                               else
                                 # Multiclass case: one-hot encode the labels and renormalize each row of the clipped probabilities.
                                 # NOTE(review): assumes the OneHotEncoder columns line up with the columns of y_pred — confirm.
                                 encoder = Rumale::Preprocessing::OneHotEncoder.new
                                 encoded_y_true = encoder.fit_transform(y_true)
                                 clipped_p /= clipped_p.sum(1).expand_dims(1)
                                 -(encoded_y_true * Numo::NMath.log(clipped_p)).sum(1)
                               end
                    log_loss.sum / n_samples
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,29 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              # Namespace holding the evaluation measures for trained models.
              module EvaluationMeasure
                # MeanAbsoluteError is a class that calculates the mean absolute error.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::MeanAbsoluteError.new
                #   puts evaluator.score(ground_truth, predicted)
                class MeanAbsoluteError
                  include Base::Evaluator

                  # Compute the average of the absolute differences between true and estimated target values.
                  #
                  # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
                  # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
                  # @raise [ArgumentError] If y_true and y_pred differ in shape.
                  # @return [Float] Mean absolute error
                  def score(y_true, y_pred)
                    [y_true, y_pred].each { |values| check_tvalue_array(values) }
                    unless y_true.shape == y_pred.shape
                      raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
                    end

                    residuals = y_true - y_pred
                    residuals.abs.mean
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,29 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              # Namespace holding the evaluation measures for trained models.
              module EvaluationMeasure
                # MeanSquaredError is a class that calculates the mean squared error.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::MeanSquaredError.new
                #   puts evaluator.score(ground_truth, predicted)
                class MeanSquaredError
                  include Base::Evaluator

                  # Compute the average of the squared differences between true and estimated target values.
                  #
                  # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
                  # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
                  # @raise [ArgumentError] If y_true and y_pred differ in shape.
                  # @return [Float] Mean squared error
                  def score(y_true, y_pred)
                    [y_true, y_pred].each { |values| check_tvalue_array(values) }
                    unless y_true.shape == y_pred.shape
                      raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
                    end

                    squared_residuals = (y_true - y_pred)**2
                    squared_residuals.mean
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,62 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::NormalizedMutualInformation.new
                #   puts evaluator.score(ground_truth, predicted)
                #
                # *Reference*
                # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
                # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
                class NormalizedMutualInformation
                  include Base::Evaluator

                  # Calculate normalized mutual information.
                  #
                  # The mutual information between the two labelings is divided by the geometric mean
                  # of their entropies. If either entropy is zero (e.g. every sample shares one label,
                  # or there are no samples), the normalization is undefined and 0.0 is returned.
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
                  # @raise [ArgumentError] If y_true and y_pred do not have the same size.
                  # @return [Float] Normalized mutual information
                  def score(y_true, y_pred)
                    check_label_array(y_true)
                    check_label_array(y_pred)
                    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.size == y_pred.size

                    class_entropy = label_entropy(y_true)
                    return 0.0 if class_entropy.zero?

                    cluster_entropy = label_entropy(y_pred)
                    return 0.0 if cluster_entropy.zero?

                    mutual_information(y_true, y_pred) / Math.sqrt(class_entropy * cluster_entropy)
                  end

                  private

                  # Entropy (natural logarithm) of the empirical distribution of the given labels; zero for empty input.
                  def label_entropy(labels)
                    n_samples = labels.size
                    # The 0.0 seed keeps empty input from producing nil, which the caller could not negate.
                    neg_entropy = labels.to_a.uniq.reduce(0.0) do |acc, k|
                      ratio = labels.eq(k).count.fdiv(n_samples)
                      acc + ratio * Math.log(ratio)
                    end
                    -neg_entropy
                  end

                  # Mutual information (natural logarithm) between the ground truth labels and the cluster labels,
                  # accumulated cluster by cluster over every (cluster, class) pair that shares at least one sample.
                  def mutual_information(y_true, y_pred)
                    n_samples = y_true.size
                    # Class membership masks do not depend on the cluster, so build them once.
                    class_masks = y_true.to_a.uniq.map { |j| y_true.eq(j) }
                    y_pred.to_a.uniq.reduce(0.0) do |total, k|
                      cluster_mask = y_pred.eq(k)
                      n_pr_samples = cluster_mask.count
                      class_masks.reduce(total) do |acc, class_mask|
                        n_intr_samples = (cluster_mask & class_mask).count
                        # Empty intersections contribute nothing (and would make the logarithm diverge).
                        next acc unless n_intr_samples.positive?

                        n_tr_samples = class_mask.count
                        joint_ratio = n_intr_samples.fdiv(n_samples)
                        acc + joint_ratio * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
                      end
                    end
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,50 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
            require 'rumale/evaluation_measure/precision_recall'
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # Precision is a class that calculates the precision of the predicted labels.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::Precision.new
                #   puts evaluator.score(ground_truth, predicted)
                class Precision
                  include Base::Evaluator
                  include EvaluationMeasure::PrecisionRecall

                  # Return the average type for calculation of precision.
                  # @return [String] ('binary', 'micro', 'macro')
                  attr_reader :average

                  # Create a new evaluation measure calculator for precision score.
                  #
                  # @param average [String] The average type ('binary', 'micro', 'macro')
                  def initialize(average: 'binary')
                    check_params_string(average: average)
                    @average = average
                  end

                  # Calculate average precision.
                  #
                  # 'binary' returns the precision of the last entry produced by precision_each_class,
                  # 'micro' the micro-averaged precision, and 'macro' the macro-averaged precision.
                  # An average type other than these three makes this method return nil.
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
                  # @return [Float] Average precision
                  def score(y_true, y_pred)
                    check_label_array(y_true)
                    check_label_array(y_pred)

                    if @average == 'binary'
                      precision_each_class(y_true, y_pred).last
                    elsif @average == 'micro'
                      micro_average_precision(y_true, y_pred)
                    elsif @average == 'macro'
                      macro_average_precision(y_true, y_pred)
                    end
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,91 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              # This module consists of the classes for model evaluation.
              module EvaluationMeasure
                # @!visibility private
                module PrecisionRecall
                  module_function

                  # Compute the precision of each label appearing in y_true (labels in ascending order).
                  # A label that never occurs in y_pred gets a precision of 0.0.
                  # @!visibility private
                  def precision_each_class(y_true, y_pred)
                    y_true.sort.to_a.uniq.map do |label|
                      target_positions = y_pred.eq(label)
                      next 0.0 if y_pred[target_positions].empty?
                      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
                      n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
                      n_true_positives / (n_true_positives + n_false_positives)
                    end
                  end

                  # Compute the recall of each label appearing in y_true (labels in ascending order).
                  # @!visibility private
                  def recall_each_class(y_true, y_pred)
                    y_true.sort.to_a.uniq.map do |label|
                      target_positions = y_true.eq(label)
                      next 0.0 if y_pred[target_positions].empty?
                      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
                      n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
                      n_true_positives / (n_true_positives + n_false_negatives)
                    end
                  end

                  # Compute the F-score of each label from its precision and recall.
                  # Returns 0.0 for a label whose precision and recall are both zero (avoids 0.0 / 0.0).
                  # @!visibility private
                  def f_score_each_class(y_true, y_pred)
                    precision_each_class(y_true, y_pred).zip(recall_each_class(y_true, y_pred)).map do |p, r|
                      next 0.0 if p.zero? && r.zero?
                      (2.0 * p * r) / (p + r)
                    end
                  end

                  # Compute the micro-averaged precision: true positives summed over the labels of y_true,
                  # divided by the number of samples predicted as one of those labels.
                  # Returns 0.0 when none of those labels is ever predicted, matching precision_each_class,
                  # instead of 0.0 / 0.0 = NaN.
                  # @!visibility private
                  def micro_average_precision(y_true, y_pred)
                    evaluated_values = y_true.sort.to_a.uniq.map do |label|
                      target_positions = y_pred.eq(label)
                      next [0.0, 0.0] if y_pred[target_positions].empty?
                      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
                      n_false_positives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
                      [n_true_positives, n_true_positives + n_false_positives]
                    end
                    total_true_positives, total_predicted = evaluated_values.transpose.map { |v| v.inject(:+) }
                    total_predicted.zero? ? 0.0 : total_true_positives / total_predicted
                  end

                  # Compute the micro-averaged recall: true positives summed over the labels of y_true,
                  # divided by the number of samples carrying one of those labels in y_true.
                  # @!visibility private
                  def micro_average_recall(y_true, y_pred)
                    evaluated_values = y_true.sort.to_a.uniq.map do |label|
                      target_positions = y_true.eq(label)
                      # Must yield a [tp, total] pair like every other element; a bare 0.0 would make
                      # the transpose below raise IndexError.
                      next [0.0, 0.0] if y_pred[target_positions].empty?
                      n_true_positives = Numo::Int32.cast(y_true[target_positions].eq(y_pred[target_positions])).sum.to_f
                      n_false_negatives = Numo::Int32.cast(y_true[target_positions].ne(y_pred[target_positions])).sum.to_f
                      [n_true_positives, n_true_positives + n_false_negatives]
                    end
                    total_true_positives, total_actual = evaluated_values.transpose.map { |v| v.inject(:+) }
                    total_actual.zero? ? 0.0 : total_true_positives / total_actual
                  end

                  # Compute the micro-averaged F-score from the micro-averaged precision and recall.
                  # Returns 0.0 when both are zero, the same convention as f_score_each_class.
                  # @!visibility private
                  def micro_average_f_score(y_true, y_pred)
                    precision = micro_average_precision(y_true, y_pred)
                    recall = micro_average_recall(y_true, y_pred)
                    return 0.0 if precision.zero? && recall.zero?
                    (2.0 * precision * recall) / (precision + recall)
                  end

                  # Compute the macro-averaged precision: the unweighted mean of precision_each_class.
                  # @!visibility private
                  def macro_average_precision(y_true, y_pred)
                    precision_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
                  end

                  # Compute the macro-averaged recall: the unweighted mean of recall_each_class.
                  # @!visibility private
                  def macro_average_recall(y_true, y_pred)
                    recall_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
                  end

                  # Compute the macro-averaged F-score: the unweighted mean of f_score_each_class.
                  # @!visibility private
                  def macro_average_f_score(y_true, y_pred)
                    f_score_each_class(y_true, y_pred).inject(:+) / y_true.to_a.uniq.size
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,40 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module Rumale
              module EvaluationMeasure
                # Purity is a class that calculates the purity of clustering results.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::Purity.new
                #   puts evaluator.score(ground_truth, predicted)
                #
                # *Reference*
                # - C. D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
                class Purity
                  include Base::Evaluator

                  # Calculate purity.
                  #
                  # For each predicted cluster, count the samples it shares with the ground-truth class
                  # it overlaps most. Purity is the sum of these counts divided by the number of samples.
                  #
                  # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
                  # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
                  # @return [Float] Purity
                  def score(y_true, y_pred)
                    check_label_array(y_true)
                    check_label_array(y_pred)
                    # The sample indices of each ground-truth class do not depend on the cluster,
                    # so compute them once rather than once per cluster.
                    class_sample_ids = y_true.to_a.uniq.map { |j| y_true.eq(j).where.to_a }
                    n_majority_samples = y_pred.to_a.uniq.map do |k|
                      pr_sample_ids = y_pred.eq(k).where.to_a
                      class_sample_ids.map { |tr_sample_ids| (pr_sample_ids & tr_sample_ids).size }.max
                    end.inject(0, :+)
                    n_majority_samples.fdiv(y_pred.size)
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,43 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/evaluator'
         
     | 
| 
      
 4 
     | 
    
         
            +
            require 'rumale/evaluation_measure/precision_recall'
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            module Rumale
              module EvaluationMeasure
                # R2Score is a class that calculates the coefficient of determination for the predicted values.
                #
                # @example
                #   evaluator = Rumale::EvaluationMeasure::R2Score.new
                #   puts evaluator.score(ground_truth, predicted)
                class R2Score
                  include Base::Evaluator

                  # Create a new evaluation measure calculator for coefficient of determination.
                  def initialize; end

                  # Calculate the coefficient of determination.
                  #
                  # For a one-dimensional target the score is returned directly. For several outputs the
                  # per-output scores are averaged with equal weight. An output whose ground truth has zero
                  # total sum of squares contributes a score of 0.0.
                  #
                  # @param y_true [Numo::DFloat] (shape: [n_samples] or [n_samples, n_outputs]) Ground truth target values.
                  # @param y_pred [Numo::DFloat] (shape: [n_samples] or [n_samples, n_outputs]) Estimated target values.
                  # @return [Float] Coefficient of determination
                  def score(y_true, y_pred)
                    check_tvalue_array(y_true)
                    check_tvalue_array(y_pred)
                    raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

                    n_samples, n_outputs = y_true.shape
                    residual_sum_sq = ((y_true - y_pred)**2).sum(0)
                    total_sum_sq = ((y_true - y_true.sum(0) / n_samples)**2).sum(0)

                    if n_outputs
                      # Multiple outputs: score each column, then zero out constant-target columns.
                      output_scores = 1 - residual_sum_sq / total_sum_sq
                      output_scores[total_sum_sq.eq(0)] = 0.0
                      output_scores.sum / output_scores.size
                    else
                      total_sum_sq.zero? ? 0.0 : 1.0 - residual_sum_sq / total_sum_sq
                    end
                  end
                end
              end
            end
         
     |