rumale 0.8.0 → 0.8.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +186 -0
- data/README.md +2 -2
- data/lib/rumale.rb +5 -0
- data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +74 -0
- data/lib/rumale/evaluation_measure/explained_variance_score.rb +39 -0
- data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +29 -0
- data/lib/rumale/evaluation_measure/median_absolute_error.rb +30 -0
- data/lib/rumale/evaluation_measure/mutual_information.rb +49 -0
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +19 -30
- data/lib/rumale/model_selection/cross_validation.rb +1 -1
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +1 -1
- metadata +8 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dab9c67aa39f19e73859d41013363b4f3811142e
|
4
|
+
data.tar.gz: 49b1d14b9261f2ede4dc97b4353efdf9032872d2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 1d2b62e0660586f4ace811f06bdd73e9ff5adb877682a9ae53e29c76188b6fe1215d1305953d292668dcc4f0b5fba71399ad746e8976572a19bbd6a4c1153829
|
7
|
+
data.tar.gz: ce075327208560af72f0b54d7113b39a23c98c1828a3f5c3f5b28b3d987be3d81af465aae943aa27f2883cc671337b1952b4f8d8981a970c3727e0c94affbef5
|
data/CHANGELOG.md
CHANGED
@@ -1,2 +1,188 @@
|
|
1
|
+
# 0.8.1
|
2
|
+
- Add some evaluator classes.
|
3
|
+
- MeanSquaredLogError
|
4
|
+
- MedianAbsoluteError
|
5
|
+
- ExplainedVarianceScore
|
6
|
+
- AdjustedRandScore
|
7
|
+
- MutualInformation
|
8
|
+
- Refactor normalized mutual information evaluator.
|
9
|
+
- Fix typo on document ([#2](https://github.com/yoshoku/rumale/pull/2)).
|
10
|
+
|
1
11
|
# 0.8.0
|
12
|
+
## Breaking changes
|
2
13
|
- Rename SVMKit to Rumale.
|
14
|
+
- Rename SGDLinearEstimator class to BaseLinearModel class.
|
15
|
+
- Add data type option to load_libsvm_file method. By default, the method represents the feature with Numo::DFloat.
|
16
|
+
|
17
|
+
## Refactoring
|
18
|
+
- Refactor factorization machine estimators.
|
19
|
+
- Refactor decision tree estimators.
|
20
|
+
|
21
|
+
# 0.7.3
|
22
|
+
- Add class for grid search performing hyperparameter optimization.
|
23
|
+
- Add argument validations to Pipeline.
|
24
|
+
|
25
|
+
# 0.7.2
|
26
|
+
- Add class for Pipeline that constructs chain of transformers and estimators.
|
27
|
+
- Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
|
28
|
+
|
29
|
+
# 0.7.1
|
30
|
+
- Fix to use CSV class in parsing libsvm format file.
|
31
|
+
- Refactor ensemble estimators.
|
32
|
+
|
33
|
+
# 0.7.0
|
34
|
+
- Add class for AdaBoost classifier.
|
35
|
+
- Add class for AdaBoost regressor.
|
36
|
+
|
37
|
+
# 0.6.3
|
38
|
+
- Fix bug on setting random seed and max_features parameter of Random Forest estimators.
|
39
|
+
|
40
|
+
# 0.6.2
|
41
|
+
- Refactor decision tree classes for improving performance.
|
42
|
+
|
43
|
+
# 0.6.1
|
44
|
+
- Add abstract class for linear estimator with stochastic gradient descent.
|
45
|
+
- Refactor linear estimators to use linear estimator abstract class.
|
46
|
+
- Refactor decision tree classes to avoid unneeded type conversion.
|
47
|
+
|
48
|
+
# 0.6.0
|
49
|
+
- Add class for Principal Component Analysis.
|
50
|
+
- Add class for Non-negative Matrix Factorization.
|
51
|
+
|
52
|
+
# 0.5.2
|
53
|
+
- Add class for DBSCAN clustering.
|
54
|
+
|
55
|
+
# 0.5.1
|
56
|
+
- Fix bug on class probability calculation of DecisionTreeClassifier.
|
57
|
+
|
58
|
+
# 0.5.0
|
59
|
+
- Add class for K-Means clustering.
|
60
|
+
- Add class for evaluating purity.
|
61
|
+
- Add class for evaluating normalized mutual information.
|
62
|
+
|
63
|
+
# 0.4.1
|
64
|
+
- Add class for linear regressor.
|
65
|
+
- Add class for SGD optimizer.
|
66
|
+
- Add class for RMSProp optimizer.
|
67
|
+
- Add class for YellowFin optimizer.
|
68
|
+
- Fix to be able to select optimizer on estimators of LinearModel and PolynomialModel.
|
69
|
+
|
70
|
+
# 0.4.0
|
71
|
+
## Breaking changes
|
72
|
+
|
73
|
+
SVMKit introduces optimizer algorithm that calculates learning rates adaptively
|
74
|
+
on each iteration of stochastic gradient descent (SGD).
|
75
|
+
While Pegasos SGD runs fast, it sometimes fails to optimize complicated models
|
76
|
+
like Factorization Machine.
|
77
|
+
To solve this problem, in version 0.3.3, SVMKit introduced optimization with RMSProp on
|
78
|
+
FactorizationMachineRegressor, Ridge and Lasso.
|
79
|
+
This attempt realized stable optimization of those estimators.
|
80
|
+
Following the success of the attempt, the author decided to use modern optimizer algorithms
|
81
|
+
with all SGD optimizations in SVMKit.
|
82
|
+
Through some preliminary experiments, the author implemented Nadam as the default optimizer.
|
83
|
+
SVMKit plans to add other optimizer algorithms sequentially, so that users can select them.
|
84
|
+
|
85
|
+
- Fix to use Nadam for optimization on SVC, SVR, LogisticRegression, Ridge, Lasso, and Factorization Machine estimators.
|
86
|
+
- Combine reg_param_weight and reg_param_bias parameters on Factorization Machine estimators into the unified parameter named reg_param_linear.
|
87
|
+
- Remove init_std parameter on Factorization Machine estimators.
|
88
|
+
- Remove learning_rate, decay, and momentum parameters on Ridge, Lasso, and FactorizationMachineRegressor.
|
89
|
+
- Remove normalize parameter on SVC, SVR, and LogisticRegression.
|
90
|
+
|
91
|
+
# 0.3.3
|
92
|
+
- Add class for Ridge regressor.
|
93
|
+
- Add class for Lasso regressor.
|
94
|
+
- Fix bug on gradient calculation of FactorizationMachineRegressor.
|
95
|
+
- Fix some documents.
|
96
|
+
|
97
|
+
# 0.3.2
|
98
|
+
- Add class for Factorization Machine regressor.
|
99
|
+
- Add class for Decision Tree regressor.
|
100
|
+
- Add class for Random Forest regressor.
|
101
|
+
- Fix to support loading and dumping libsvm file with multi-target variables.
|
102
|
+
- Fix to require DecisionTreeClassifier on RandomForestClassifier.
|
103
|
+
- Fix some mistakes on document.
|
104
|
+
|
105
|
+
# 0.3.1
|
106
|
+
- Fix bug on decision function calculation of FactorizationMachineClassifier.
|
107
|
+
- Fix bug on weight updating process of KernelSVC.
|
108
|
+
|
109
|
+
# 0.3.0
|
110
|
+
- Add class for Support Vector Regression.
|
111
|
+
- Add class for K-Nearest Neighbor Regression.
|
112
|
+
- Add class for evaluating coefficient of determination.
|
113
|
+
- Add class for evaluating mean squared error.
|
114
|
+
- Add class for evaluating mean absolute error.
|
115
|
+
- Fix to use min method instead of sort and first methods.
|
116
|
+
- Fix cross validation class to be able to use for regression problem.
|
117
|
+
- Fix some typos on document.
|
118
|
+
- Rename spec filename for Factorization Machine classifier.
|
119
|
+
|
120
|
+
# 0.2.9
|
121
|
+
- Add predict_proba method to SVC and KernelSVC.
|
122
|
+
- Add class for evaluating logarithmic loss.
|
123
|
+
- Add classes for Label- and One-Hot- encoding.
|
124
|
+
- Add some validators.
|
125
|
+
- Fix bug on training data score calculation of cross validation.
|
126
|
+
- Fix fit method of SVC for performance.
|
127
|
+
- Fix criterion calculation on Decision Tree for performance.
|
128
|
+
- Fix data structure of Decision Tree for performance.
|
129
|
+
|
130
|
+
# 0.2.8
|
131
|
+
- Fix bug on gradient calculation of Logistic Regression.
|
132
|
+
- Fix to change accessor of params of estimators to read only.
|
133
|
+
- Add parameter validation.
|
134
|
+
|
135
|
+
# 0.2.7
|
136
|
+
- Fix to support multiclass classification into LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier
|
137
|
+
|
138
|
+
# 0.2.6
|
139
|
+
- Add class for Decision Tree classifier.
|
140
|
+
- Add class for Random Forest classifier.
|
141
|
+
- Fix to use frozen string literal.
|
142
|
+
- Refactor marshal dump method on some classes.
|
143
|
+
- Introduce Coveralls to confirm test coverage.
|
144
|
+
|
145
|
+
# 0.2.5
|
146
|
+
- Add classes for Naive Bayes classifier.
|
147
|
+
- Fix decision function method on Logistic Regression class.
|
148
|
+
- Fix method visibility on RBF kernel approximation class.
|
149
|
+
|
150
|
+
# 0.2.4
|
151
|
+
- Add class for Factorization Machine classifier.
|
152
|
+
- Add classes for evaluation measures.
|
153
|
+
- Fix the method for prediction of class probability in Logistic Regression.
|
154
|
+
|
155
|
+
# 0.2.3
|
156
|
+
- Add class for cross validation.
|
157
|
+
- Add specs for base modules.
|
158
|
+
- Fix validation of the number of splits when a negative label is given.
|
159
|
+
|
160
|
+
# 0.2.2
|
161
|
+
- Add data splitter classes for K-fold cross validation.
|
162
|
+
|
163
|
+
# 0.2.1
|
164
|
+
- Add class for K-nearest neighbors classifier.
|
165
|
+
|
166
|
+
# 0.2.0
|
167
|
+
- Migrated the linear algebra library to Numo::NArray.
|
168
|
+
- Add module for loading and saving libsvm format file.
|
169
|
+
|
170
|
+
# 0.1.3
|
171
|
+
- Add class for Kernel Support Vector Machine with Pegasos algorithm.
|
172
|
+
- Add module for calculating pairwise kernel functions and euclidean distances.
|
173
|
+
|
174
|
+
# 0.1.2
|
175
|
+
- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
|
176
|
+
- Rewrite the document with yard notation.
|
177
|
+
|
178
|
+
# 0.1.1
|
179
|
+
- Add class for Logistic Regression with SGD optimization.
|
180
|
+
- Fix some mistakes on the document.
|
181
|
+
|
182
|
+
# 0.1.0
|
183
|
+
- Add basic classes.
|
184
|
+
- Add a utility module.
|
185
|
+
- Add class for RBF kernel approximation.
|
186
|
+
- Add class for Support Vector Machine with Pegasos algorithm.
|
187
|
+
- Add class that performs multiclass classification with one-vs.-rest strategy.
|
188
|
+
- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
|
data/README.md
CHANGED
@@ -13,7 +13,7 @@ Logistic Regression, Linear Regression, Ridge, Lasso, Factorization Machine,
|
|
13
13
|
Naive Bayes, Decision Tree, AdaBoost, Random Forest, K-nearest neighbor classifier,
|
14
14
|
K-Means, DBSCAN, Principal Component Analysis, and Non-negative Matrix Factorization.
|
15
15
|
|
16
|
-
This project was formerly known as "
|
16
|
+
This project was formerly known as "SVMKit".
|
17
17
|
If you are using SVMKit, please install Rumale and replace `SVMKit` constants with `Rumale`.
|
18
18
|
|
19
19
|
## Installation
|
@@ -161,7 +161,7 @@ To install this gem onto your local machine, run `bundle exec rake install`. To
|
|
161
161
|
|
162
162
|
## Contributing
|
163
163
|
|
164
|
-
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/
|
164
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
|
165
165
|
This project is intended to be a safe, welcoming space for collaboration,
|
166
166
|
and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
|
167
167
|
|
data/lib/rumale.rb
CHANGED
@@ -64,7 +64,12 @@ require 'rumale/evaluation_measure/recall'
|
|
64
64
|
require 'rumale/evaluation_measure/f_score'
|
65
65
|
require 'rumale/evaluation_measure/log_loss'
|
66
66
|
require 'rumale/evaluation_measure/r2_score'
|
67
|
+
require 'rumale/evaluation_measure/explained_variance_score'
|
67
68
|
require 'rumale/evaluation_measure/mean_squared_error'
|
69
|
+
require 'rumale/evaluation_measure/mean_squared_log_error'
|
68
70
|
require 'rumale/evaluation_measure/mean_absolute_error'
|
71
|
+
require 'rumale/evaluation_measure/median_absolute_error'
|
72
|
+
require 'rumale/evaluation_measure/adjusted_rand_score'
|
69
73
|
require 'rumale/evaluation_measure/purity'
|
74
|
+
require 'rumale/evaluation_measure/mutual_information'
|
70
75
|
require 'rumale/evaluation_measure/normalized_mutual_information'
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/evaluator'
|
4
|
+
|
5
|
+
module Rumale
  module EvaluationMeasure
    # AdjustedRandScore is a class that calculates the adjusted rand index.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::AdjustedRandScore.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance", J. Machine Learning Research, Vol. 11, pp.2837--2854, 2010.
    class AdjustedRandScore
      include Base::Evaluator

      # Calculate adjusted rand index.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @return [Float] Adjusted rand index.
      def score(y_true, y_pred)
        check_label_array(y_true)
        check_label_array(y_pred)

        # initialize some variables.
        n_samples = y_pred.size
        n_classes = y_true.to_a.uniq.size
        n_clusters = y_pred.to_a.uniq.size

        # check special cases (these are also the only inputs for which the
        # denominator below would be zero).
        return 1.0 if special_cases?(n_samples, n_classes, n_clusters)

        # calculate adjusted rand index.
        # The pair counts are accumulated as Ruby Integers: mapping comb_two over a
        # Numo::Int32 array would store each C(k, 2) back into 32-bit storage and wrap
        # once a row, column, or cell holds more than 65_536 samples.
        table = contingency_table(y_true, y_pred)
        sum_comb_a = sum_comb_two(table.sum(axis: 1))
        sum_comb_b = sum_comb_two(table.sum(axis: 0))
        sum_comb = sum_comb_two(table.flatten)
        prod_comb = (sum_comb_a * sum_comb_b).fdiv(comb_two(n_samples))
        mean_comb = (sum_comb_a + sum_comb_b).fdiv(2)
        (sum_comb - prod_comb).fdiv(mean_comb - prod_comb)
      end

      private

      # Build the (n_classes x n_clusters) table of co-occurrence counts.
      def contingency_table(y_true, y_pred)
        class_ids = y_true.to_a.uniq
        cluster_ids = y_pred.to_a.uniq
        n_classes = class_ids.size
        n_clusters = cluster_ids.size
        table = Numo::Int32.zeros(n_classes, n_clusters)
        n_classes.times do |i|
          b_true = y_true.eq(class_ids[i])
          n_clusters.times do |j|
            b_pred = y_pred.eq(cluster_ids[j])
            table[i, j] = (b_true & b_pred).count
          end
        end
        table
      end

      # Inputs for which the two labelings are trivially identical (empty, a single
      # group each, or all singletons each).
      def special_cases?(n_samples, n_classes, n_clusters)
        ((n_classes.zero? && n_clusters.zero?) ||
         (n_classes == 1 && n_clusters == 1) ||
         (n_classes == n_samples && n_clusters == n_samples))
      end

      # Sum C(k, 2) over the given counts using arbitrary-precision Ruby Integers.
      def sum_comb_two(counts)
        counts.to_a.inject(0) { |acc, k| acc + comb_two(k) }
      end

      # Number of unordered pairs among k items: C(k, 2).
      def comb_two(k)
        k * (k - 1) / 2
      end
    end
  end
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/evaluator'
|
4
|
+
|
5
|
+
module Rumale
  module EvaluationMeasure
    # ExplainedVarianceScore is a class that calculates the explained variance score.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::ExplainedVarianceScore.new
    #   puts evaluator.score(ground_truth, predicted)
    class ExplainedVarianceScore
      include Base::Evaluator

      # Calculate explained variance score.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] if y_true and y_pred have different shapes.
      # @return [Float] Explained variance score.
      def score(y_true, y_pred)
        check_tvalue_array(y_true)
        check_tvalue_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        # Variance of the residuals and variance of the ground truth, per output.
        diff = y_true - y_pred
        numerator = ((diff - diff.mean(0))**2).mean(0)
        denominator = ((y_true - y_true.mean(0))**2).mean(0)

        n_outputs = y_true.shape[1]
        if n_outputs.nil?
          # Single output: a constant ground truth (zero variance) scores 0.0.
          # Return a Float here, as documented, rather than the Integer 0.
          denominator.zero? ? 0.0 : 1.0 - numerator / denominator
        else
          # Outputs with constant ground truth are excluded from the sum but still counted
          # in n_outputs, so each contributes a score of 0 to the average.
          valids = denominator.ne(0)
          (1.0 - numerator[valids] / denominator[valids]).sum / n_outputs
        end
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/evaluator'
|
4
|
+
|
5
|
+
module Rumale
  module EvaluationMeasure
    # MeanSquaredLogError is a class that calculates the mean squared logarithmic error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MeanSquaredLogError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MeanSquaredLogError
      include Base::Evaluator

      # Calculate mean squared logarithmic error.
      #
      # The logarithm is taken of (value + 1), so target values are expected to be
      # greater than -1.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples, n_outputs]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_outputs]) Estimated target values.
      # @raise [ArgumentError] if y_true and y_pred have different shapes.
      # @return [Float] Mean squared logarithmic error.
      def score(y_true, y_pred)
        check_tvalue_array(y_true)
        check_tvalue_array(y_pred)
        raise ArgumentError, 'Expect to have the same size both y_true and y_pred.' unless y_true.shape == y_pred.shape

        # Mean over all elements (every sample and every output).
        ((Numo::NMath.log(y_true + 1) - Numo::NMath.log(y_pred + 1))**2).mean
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/evaluator'
|
4
|
+
|
5
|
+
module Rumale
  module EvaluationMeasure
    # MedianAbsoluteError is a class that calculates the median absolute error.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MedianAbsoluteError.new
    #   puts evaluator.score(ground_truth, predicted)
    class MedianAbsoluteError
      include Base::Evaluator

      # Calculate median absolute error.
      #
      # Only one-dimensional target arrays are accepted.
      #
      # @param y_true [Numo::DFloat] (shape: [n_samples]) Ground truth target values.
      # @param y_pred [Numo::DFloat] (shape: [n_samples]) Estimated target values.
      # @raise [ArgumentError] if the shapes differ or either array has more than one dimension.
      # @return [Float] Median absolute error.
      def score(y_true, y_pred)
        [y_true, y_pred].each { |values| check_tvalue_array(values) }
        unless y_true.shape == y_pred.shape
          raise ArgumentError, 'Expect to have the same size both y_true and y_pred.'
        end
        if y_true.shape.size > 1 || y_pred.shape.size > 1
          raise ArgumentError, 'Expect target values to be 1-D arrray'
        end

        abs_errors = (y_true - y_pred).abs
        abs_errors.median
      end
    end
  end
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/evaluator'
|
4
|
+
|
5
|
+
module Rumale
  module EvaluationMeasure
    # MutualInformation is a class that calculates the mutual information.
    #
    # @example
    #   evaluator = Rumale::EvaluationMeasure::MutualInformation.new
    #   puts evaluator.score(ground_truth, predicted)
    #
    # *Reference*
    # - N X. Vinh, J. Epps, and J. Bailey, "Information Theoretic Measures for Clusterings Comparison: Variants, Properties, Normalization and Correction for Chance," J. Machine Learning Research, vol. 11, pp. 2837--2854, 2010.
    class MutualInformation
      include Base::Evaluator

      # Calculate mutual information (natural logarithm).
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::Int32] (shape: [n_samples]) Predicted cluster labels.
      # @return [Float] Mutual information.
      def score(y_true, y_pred)
        check_label_array(y_true)
        check_label_array(y_pred)
        # initialize some variables.
        mutual_information = 0.0
        n_samples = y_pred.size
        class_ids = y_true.to_a.uniq
        cluster_ids = y_pred.to_a.uniq
        # The sample indices of each class do not depend on the cluster, so compute them
        # once instead of once per (cluster, class) pair.
        tr_sample_ids_list = class_ids.map { |j| y_true.eq(j).where.to_a }
        # calculate mutual information.
        cluster_ids.each do |k|
          pr_sample_ids = y_pred.eq(k).where.to_a
          n_pr_samples = pr_sample_ids.size
          tr_sample_ids_list.each do |tr_sample_ids|
            n_tr_samples = tr_sample_ids.size
            n_intr_samples = (pr_sample_ids & tr_sample_ids).size
            # Empty intersections contribute nothing (and would make the log undefined).
            next unless n_intr_samples.positive?

            mutual_information +=
              n_intr_samples.fdiv(n_samples) * Math.log((n_samples * n_intr_samples).fdiv(n_pr_samples * n_tr_samples))
          end
        end
        mutual_information
      end
    end
  end
end
|
@@ -1,10 +1,11 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
3
|
require 'rumale/base/evaluator'
|
4
|
+
require 'rumale/evaluation_measure/mutual_information'
|
4
5
|
|
5
6
|
module Rumale
|
6
7
|
module EvaluationMeasure
|
7
|
-
# NormalizedMutualInformation is a class that calculates the normalized mutual information
|
8
|
+
# NormalizedMutualInformation is a class that calculates the normalized mutual information.
|
8
9
|
#
|
9
10
|
# @example
|
10
11
|
# evaluator = Rumale::EvaluationMeasure::NormalizedMutualInformation.new
|
@@ -24,38 +25,26 @@ module Rumale
|
|
24
25
|
def score(y_true, y_pred)
|
25
26
|
check_label_array(y_true)
|
26
27
|
check_label_array(y_pred)
|
27
|
-
#
|
28
|
-
|
29
|
-
n_samples = y_pred.size
|
30
|
-
class_ids = y_true.to_a.uniq
|
31
|
-
cluster_ids = y_pred.to_a.uniq
|
32
|
-
# calculate entropy.
|
33
|
-
class_entropy = -1.0 * class_ids.map do |k|
|
34
|
-
ratio = y_true.eq(k).count.fdiv(n_samples)
|
35
|
-
ratio * Math.log(ratio)
|
36
|
-
end.reduce(:+)
|
28
|
+
# calculate entropies.
|
29
|
+
class_entropy = entropy(y_true)
|
37
30
|
return 0.0 if class_entropy.zero?
|
38
|
-
cluster_entropy =
|
39
|
-
ratio = y_pred.eq(k).count.fdiv(n_samples)
|
40
|
-
ratio * Math.log(ratio)
|
41
|
-
end.reduce(:+)
|
31
|
+
cluster_entropy = entropy(y_pred)
|
42
32
|
return 0.0 if cluster_entropy.zero?
|
43
33
|
# calculate mutual information.
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
end
|
57
|
-
|
58
|
-
mutual_information / Math.sqrt(class_entropy * cluster_entropy)
|
34
|
+
mi = MutualInformation.new
|
35
|
+
mi.score(y_true, y_pred) / Math.sqrt(class_entropy * cluster_entropy)
|
36
|
+
end
|
37
|
+
|
38
|
+
private
|
39
|
+
|
40
|
+
def entropy(y)
|
41
|
+
n_samples = y.size
|
42
|
+
indices = y.to_a.uniq
|
43
|
+
sum_log = indices.map do |k|
|
44
|
+
ratio = y.eq(k).count.fdiv(n_samples)
|
45
|
+
ratio * Math.log(ratio)
|
46
|
+
end.reduce(:+)
|
47
|
+
-sum_log
|
59
48
|
end
|
60
49
|
end
|
61
50
|
end
|
@@ -17,7 +17,7 @@ module Rumale
|
|
17
17
|
# svc = Rumale::LinearModel::SVC.new
|
18
18
|
# kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
19
19
|
# cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
|
20
|
-
# report = cv.perform(samples,
|
20
|
+
# report = cv.perform(samples, labels)
|
21
21
|
# mean_test_score = report[:test_score].inject(:+) / kf.n_splits
|
22
22
|
#
|
23
23
|
class CrossValidation
|
data/lib/rumale/version.rb
CHANGED
data/rumale.gemspec
CHANGED
@@ -20,7 +20,7 @@ Logistic Regression, Linear Regression, Ridge, Lasso, Factorization Machine,
|
|
20
20
|
Naive Bayes, Decision Tree, AdaBoost, Random Forest, K-nearest neighbor algorithm,
|
21
21
|
K-Means, DBSCAN, Principal Component Analysis, and Non-negative Matrix Factorization.
|
22
22
|
MSG
|
23
|
-
spec.homepage = 'https://github.com/yoshoku/
|
23
|
+
spec.homepage = 'https://github.com/yoshoku/rumale'
|
24
24
|
spec.license = 'BSD-2-Clause'
|
25
25
|
|
26
26
|
spec.files = `git ls-files -z`.split("\x0").reject do |f|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.
|
4
|
+
version: 0.8.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-03-
|
11
|
+
date: 2019-03-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -125,10 +125,15 @@ files:
|
|
125
125
|
- lib/rumale/ensemble/random_forest_classifier.rb
|
126
126
|
- lib/rumale/ensemble/random_forest_regressor.rb
|
127
127
|
- lib/rumale/evaluation_measure/accuracy.rb
|
128
|
+
- lib/rumale/evaluation_measure/adjusted_rand_score.rb
|
129
|
+
- lib/rumale/evaluation_measure/explained_variance_score.rb
|
128
130
|
- lib/rumale/evaluation_measure/f_score.rb
|
129
131
|
- lib/rumale/evaluation_measure/log_loss.rb
|
130
132
|
- lib/rumale/evaluation_measure/mean_absolute_error.rb
|
131
133
|
- lib/rumale/evaluation_measure/mean_squared_error.rb
|
134
|
+
- lib/rumale/evaluation_measure/mean_squared_log_error.rb
|
135
|
+
- lib/rumale/evaluation_measure/median_absolute_error.rb
|
136
|
+
- lib/rumale/evaluation_measure/mutual_information.rb
|
132
137
|
- lib/rumale/evaluation_measure/normalized_mutual_information.rb
|
133
138
|
- lib/rumale/evaluation_measure/precision.rb
|
134
139
|
- lib/rumale/evaluation_measure/precision_recall.rb
|
@@ -176,7 +181,7 @@ files:
|
|
176
181
|
- lib/rumale/values.rb
|
177
182
|
- lib/rumale/version.rb
|
178
183
|
- rumale.gemspec
|
179
|
-
homepage: https://github.com/yoshoku/
|
184
|
+
homepage: https://github.com/yoshoku/rumale
|
180
185
|
licenses:
|
181
186
|
- BSD-2-Clause
|
182
187
|
metadata: {}
|