svmkit 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +0 -9
- data/.rspec +1 -0
- data/.travis.yml +4 -12
- data/LICENSE.txt +1 -1
- data/README.md +11 -13
- data/lib/svmkit.rb +3 -66
- data/svmkit.gemspec +12 -7
- metadata +16 -81
- data/.coveralls.yml +0 -1
- data/.rubocop.yml +0 -47
- data/.rubocop_todo.yml +0 -58
- data/HISTORY.md +0 -168
- data/lib/svmkit/base/base_estimator.rb +0 -13
- data/lib/svmkit/base/classifier.rb +0 -34
- data/lib/svmkit/base/cluster_analyzer.rb +0 -29
- data/lib/svmkit/base/evaluator.rb +0 -13
- data/lib/svmkit/base/regressor.rb +0 -34
- data/lib/svmkit/base/splitter.rb +0 -17
- data/lib/svmkit/base/transformer.rb +0 -18
- data/lib/svmkit/clustering/dbscan.rb +0 -127
- data/lib/svmkit/clustering/k_means.rb +0 -140
- data/lib/svmkit/dataset.rb +0 -109
- data/lib/svmkit/decomposition/nmf.rb +0 -147
- data/lib/svmkit/decomposition/pca.rb +0 -150
- data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
- data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
- data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
- data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
- data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
- data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
- data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
- data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
- data/lib/svmkit/evaluation_measure/precision.rb +0 -51
- data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
- data/lib/svmkit/evaluation_measure/purity.rb +0 -41
- data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
- data/lib/svmkit/evaluation_measure/recall.rb +0 -51
- data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
- data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
- data/lib/svmkit/linear_model/lasso.rb +0 -138
- data/lib/svmkit/linear_model/linear_regression.rb +0 -112
- data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
- data/lib/svmkit/linear_model/ridge.rb +0 -112
- data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
- data/lib/svmkit/linear_model/svc.rb +0 -184
- data/lib/svmkit/linear_model/svr.rb +0 -123
- data/lib/svmkit/model_selection/cross_validation.rb +0 -121
- data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
- data/lib/svmkit/model_selection/k_fold.rb +0 -77
- data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
- data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
- data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
- data/lib/svmkit/optimizer/nadam.rb +0 -90
- data/lib/svmkit/optimizer/rmsprop.rb +0 -69
- data/lib/svmkit/optimizer/sgd.rb +0 -65
- data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
- data/lib/svmkit/pairwise_metric.rb +0 -91
- data/lib/svmkit/pipeline/pipeline.rb +0 -197
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
- data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
- data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
- data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
- data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
- data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
- data/lib/svmkit/probabilistic_output.rb +0 -112
- data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
- data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
- data/lib/svmkit/tree/node.rb +0 -70
- data/lib/svmkit/utils.rb +0 -22
- data/lib/svmkit/validation.rb +0 -79
- data/lib/svmkit/values.rb +0 -13
- data/lib/svmkit/version.rb +0 -7
@@ -1,143 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/values'
|
5
|
-
require 'svmkit/base/base_estimator'
|
6
|
-
require 'svmkit/base/regressor'
|
7
|
-
require 'svmkit/tree/decision_tree_regressor'
|
8
|
-
|
9
|
-
module SVMKit
  module Ensemble
    # RandomForestRegressor is a class that implements random forest for regression.
    # It averages the predictions of decision tree regressors that are each
    # fitted on a bootstrap sample of the training data.
    #
    # @example
    #   estimator =
    #     SVMKit::Ensemble::RandomForestRegressor.new(
    #       n_estimators: 10, criterion: 'mse', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    class RandomForestRegressor
      include Base::BaseEstimator
      include Base::Regressor
      include Validation

      # Return the set of estimators.
      # @return [Array<DecisionTreeRegressor>] (nil until #fit has been called)
      attr_reader :estimators

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features) (nil until #fit has been called)
      attr_reader :feature_importances

      # Return the random generator used to draw the seed of each tree and the bootstrap sample indices.
      # @return [Random]
      attr_reader :rng

      # Create a new regressor with random forest.
      #
      # @param n_estimators [Integer] The number of decision trees for constructing random forest.
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mse' and 'mae'.
      #   The value is handed to each Tree::DecisionTreeRegressor unchanged.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, #fit uses the integer part of the square root of the number of features.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   It determines the seed given to each tree and the bootstrap samples drawn in #fit.
      def initialize(n_estimators: 10,
                     criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
                     max_features: nil, random_seed: nil)
        check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                          max_features: max_features, random_seed: random_seed)
        check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
        check_params_string(criterion: criterion)
        check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
                              max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
                              max_features: max_features)
        @params = {}
        @params[:n_estimators] = n_estimators
        @params[:criterion] = criterion
        @params[:max_depth] = max_depth
        @params[:max_leaf_nodes] = max_leaf_nodes
        @params[:min_samples_leaf] = min_samples_leaf
        @params[:max_features] = max_features
        @params[:random_seed] = random_seed
        # Kernel#srand without an argument reseeds the global generator and
        # returns the previous seed, which is recorded as this model's seed.
        @params[:random_seed] ||= srand
        @estimators = nil
        @feature_importances = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      #   A one-dimensional array (shape: [n_samples]) is treated as a single target.
      # @return [RandomForestRegressor] The learned regressor itself.
      def fit(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        # Initialize some variables.
        n_samples, n_features = x.shape
        # Note that the resolved value is written back into @params[:max_features].
        @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
        @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
        @feature_importances = Numo::DFloat.zeros(n_features)
        single_target = y.shape[1].nil?
        # Construct forest.
        @estimators = Array.new(@params[:n_estimators]) do
          # The tree seed is drawn before the bootstrap indices; keep this order
          # so that a given random_seed keeps producing the same forest.
          tree = Tree::DecisionTreeRegressor.new(
            criterion: @params[:criterion], max_depth: @params[:max_depth],
            max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
            max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
          )
          bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
          tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
          @feature_importances += tree.feature_importances
          tree
        end
        # Normalize so that the importances sum to one. A zero total is left
        # as it is, because dividing by it would fill the vector with NaN.
        total_importance = @feature_importances.sum
        @feature_importances /= total_importance if total_importance > 0
        self
      end

      # Predict values for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted value per sample,
      #   i.e. the sum of the predictions of all trees divided by n_estimators.
      def predict(x)
        check_sample_array(x)
        @estimators.map { |est| est.predict(x) }.reduce(&:+) / @params[:n_estimators]
      end

      # Return the index of the leaf that each sample reached.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign each leaf.
      # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
      def apply(x)
        check_sample_array(x)
        Numo::Int32[*@estimators.map { |est| est.apply(x) }].transpose
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about RandomForestRegressor.
      def marshal_dump
        { params: @params,
          estimators: @estimators,
          feature_importances: @feature_importances,
          rng: @rng }
      end

      # Load marshal data.
      # @param obj [Hash] The hash produced by #marshal_dump.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @estimators = obj[:estimators]
        @feature_importances = obj[:feature_importances]
        @rng = obj[:rng]
        nil
      end
    end
  end
end
|
@@ -1,30 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
|
6
|
-
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Accuracy computes the fraction of predicted labels that equal the ground truth labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Accuracy.new
    #   puts evaluator.score(ground_truth, predicted)
    class Accuracy
      include Base::Evaluator

      # Calculate mean accuracy.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Mean accuracy
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        # One entry per sample: 1 when the prediction matches, 0 otherwise.
        hits = y_true.to_a.each_with_index.map do |label, idx|
          if label == y_pred[idx]
            1
          else
            0
          end
        end
        n_correct = hits.reduce(:+)
        n_correct / y_true.size.to_f
      end
    end
  end
end
|
@@ -1,51 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
require 'svmkit/evaluation_measure/precision_recall'
|
6
|
-
|
7
|
-
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # FScore is a class that calculates the F1-score of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::FScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class FScore
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of F1-score.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for F1-score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        @average = average
      end

      # Calculate average F1-score.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Average F1-score. nil is returned when the average type
      #   is none of 'binary', 'micro' and 'macro'.
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        if @average == 'binary'
          # The binary score is the last of the per-class scores.
          per_class = f_score_each_class(y_true, y_pred)
          per_class.last
        elsif @average == 'micro'
          micro_average_f_score(y_true, y_pred)
        elsif @average == 'macro'
          macro_average_f_score(y_true, y_pred)
        end
      end
    end
  end
end
|
@@ -1,46 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
require 'svmkit/preprocessing/one_hot_encoder'
|
6
|
-
|
7
|
-
module SVMKit
  module EvaluationMeasure
    # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::LogLoss.new
    #   puts evaluator.score(ground_truth, predicted)
    class LogLoss
      include Base::Evaluator

      # Calculate mean logarithmic loss.
      # When y_pred is one-dimensional (shape: [n_samples]) like y_true, the loss for
      # binary classification is calculated; the smallest label in y_true is treated
      # as the negative class.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
      # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
      # @return [Float] mean logarithmic loss
      def score(y_true, y_pred, eps = 1e-15)
        SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
        SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)

        n_samples, n_classes = y_pred.shape
        # Keep the probabilities inside [eps, 1 - eps] so that the logarithms stay finite.
        probs = y_pred.clip(eps, 1 - eps)

        if n_classes.nil?
          # Binary case: 1.0 marks every sample whose label differs from the smallest one.
          negative_label = y_true.to_a.uniq.min
          is_positive = Numo::DFloat.cast(y_true.ne(negative_label))
          positive_term = is_positive * Numo::NMath.log(probs)
          negative_term = (1 - is_positive) * Numo::NMath.log(1 - probs)
          per_sample_loss = -(positive_term + negative_term)
        else
          # Multi-class case: rescale each row to sum to one after clipping.
          one_hot = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(y_true)
          row_totals = probs.sum(1).expand_dims(1)
          probs = probs / row_totals
          per_sample_loss = -(one_hot * Numo::NMath.log(probs)).sum(1)
        end

        per_sample_loss.sum / n_samples
      end
    end
  end
end
|
@@ -1,30 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
|
6
|
-
module SVMKit
  module EvaluationMeasure
    # MeanAbsoluteError averages the absolute differences between target values and their estimates.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanAbsoluteError
      include Base::Evaluator

      # Calculate mean absolute error.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If the shapes of y_true and y_pred differ.
      # @return [Float] Mean absolute error
      def score(y_true, y_pred)
        SVMKit::Validation.check_tvalue_array(y_true)
        SVMKit::Validation.check_tvalue_array(y_pred)
        if y_true.shape != y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        residuals = y_true - y_pred
        residuals.abs.mean
      end
    end
  end
end
|
@@ -1,30 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
|
6
|
-
module SVMKit
  module EvaluationMeasure
    # MeanSquaredError averages the squared differences between target values and their estimates.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::MeanSquaredError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredError
      include Base::Evaluator

      # Calculate mean squared error.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] If the shapes of y_true and y_pred differ.
      # @return [Float] Mean squared error
      def score(y_true, y_pred)
        SVMKit::Validation.check_tvalue_array(y_true)
        SVMKit::Validation.check_tvalue_array(y_pred)
        if y_true.shape != y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end

        residuals = y_true - y_pred
        (residuals**2).mean
      end
    end
  end
end
|
@@ -1,63 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
|
6
|
-
module SVMKit
  module EvaluationMeasure
    # NormalizedMutualInformation is a class that calculates the normalized mutual information of clustering results.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::NormalizedMutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - C D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press., 2008.
    # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class NormalizedMutualInformation
      include Base::Evaluator

      # Calculate normalized mutual information, i.e. the mutual information between
      # the classes and the clusters divided by the square root of the product of
      # their entropies.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @return [Float] Normalized mutual information. 0.0 is returned when the class
      #   entropy or the cluster entropy is zero (only one class or only one cluster).
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)
        # initialize some variables.
        mutual_information = 0.0
        n_samples = y_pred.size
        class_ids = y_true.to_a.uniq
        cluster_ids = y_pred.to_a.uniq
        # calculate entropy.
        # (reduce(:+) is kept on purpose: Array#sum uses compensated summation
        # for floats and could change the result in the last digits.)
        class_terms = class_ids.map do |k|
          ratio = y_true.eq(k).count.fdiv(n_samples)
          ratio * Math.log(ratio)
        end
        class_entropy = -1.0 * class_terms.reduce(:+)
        return 0.0 if class_entropy.zero?
        cluster_terms = cluster_ids.map do |k|
          ratio = y_pred.eq(k).count.fdiv(n_samples)
          ratio * Math.log(ratio)
        end
        cluster_entropy = -1.0 * cluster_terms.reduce(:+)
        return 0.0 if cluster_entropy.zero?
        # calculate mutual information.
        # The sample indices of each class do not depend on the cluster, so they
        # are collected once instead of once per cluster.
        class_sample_ids = class_ids.map { |j| y_true.eq(j).where.to_a }
        cluster_ids.each do |k|
          pr_sample_ids = y_pred.eq(k).where.to_a
          n_pr_samples = pr_sample_ids.size
          class_sample_ids.each do |tr_sample_ids|
            n_tr_samples = tr_sample_ids.size
            n_intr_samples = (pr_sample_ids & tr_sample_ids).size
            # An empty intersection contributes nothing (and log(0) is undefined).
            next unless n_intr_samples > 0
            mutual_information +=
              n_intr_samples.fdiv(n_samples) * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
          end
        end
        # return normalized mutual information.
        mutual_information / Math.sqrt(class_entropy * cluster_entropy)
      end
    end
  end
end
|
@@ -1,51 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require 'svmkit/validation'
|
4
|
-
require 'svmkit/base/evaluator'
|
5
|
-
require 'svmkit/evaluation_measure/precision_recall'
|
6
|
-
|
7
|
-
module SVMKit
  # This module consists of the classes for model evaluation.
  module EvaluationMeasure
    # Precision is a class that calculates the precision of the predicted labels.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::Precision.new
    #   puts evaluator.score(ground_truth, predicted)
    class Precision
      include Base::Evaluator
      include EvaluationMeasure::PrecisionRecall

      # Return the average type for calculation of precision.
      # @return [String] ('binary', 'micro', 'macro')
      attr_reader :average

      # Create a new evaluation measure calculator for precision score.
      #
      # @param average [String] The average type ('binary', 'micro', 'macro')
      def initialize(average: 'binary')
        SVMKit::Validation.check_params_string(average: average)
        @average = average
      end

      # Calculate average precision.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted labels.
      # @return [Float] Average precision. nil is returned when the average type
      #   is none of 'binary', 'micro' and 'macro'.
      def score(y_true, y_pred)
        SVMKit::Validation.check_label_array(y_true)
        SVMKit::Validation.check_label_array(y_pred)

        if @average == 'binary'
          # The binary score is the last of the per-class scores.
          per_class = precision_each_class(y_true, y_pred)
          per_class.last
        elsif @average == 'micro'
          micro_average_precision(y_true, y_pred)
        elsif @average == 'macro'
          macro_average_precision(y_true, y_pred)
        end
      end
    end
  end
end
|