rumale 0.8.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.coveralls.yml +1 -0
- data/.gitignore +20 -0
- data/.rspec +3 -0
- data/.rubocop.yml +47 -0
- data/.rubocop_todo.yml +58 -0
- data/.travis.yml +13 -0
- data/CHANGELOG.md +2 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +175 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/rumale.rb +70 -0
- data/lib/rumale/base/base_estimator.rb +13 -0
- data/lib/rumale/base/classifier.rb +36 -0
- data/lib/rumale/base/cluster_analyzer.rb +31 -0
- data/lib/rumale/base/evaluator.rb +17 -0
- data/lib/rumale/base/regressor.rb +36 -0
- data/lib/rumale/base/splitter.rb +21 -0
- data/lib/rumale/base/transformer.rb +22 -0
- data/lib/rumale/clustering/dbscan.rb +125 -0
- data/lib/rumale/clustering/k_means.rb +138 -0
- data/lib/rumale/dataset.rb +110 -0
- data/lib/rumale/decomposition/nmf.rb +141 -0
- data/lib/rumale/decomposition/pca.rb +148 -0
- data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
- data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
- data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
- data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
- data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
- data/lib/rumale/evaluation_measure/f_score.rb +50 -0
- data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
- data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
- data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
- data/lib/rumale/evaluation_measure/precision.rb +50 -0
- data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
- data/lib/rumale/evaluation_measure/purity.rb +40 -0
- data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
- data/lib/rumale/evaluation_measure/recall.rb +50 -0
- data/lib/rumale/kernel_approximation/rbf.rb +121 -0
- data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
- data/lib/rumale/linear_model/base_linear_model.rb +89 -0
- data/lib/rumale/linear_model/lasso.rb +136 -0
- data/lib/rumale/linear_model/linear_regression.rb +110 -0
- data/lib/rumale/linear_model/logistic_regression.rb +159 -0
- data/lib/rumale/linear_model/ridge.rb +110 -0
- data/lib/rumale/linear_model/svc.rb +183 -0
- data/lib/rumale/linear_model/svr.rb +122 -0
- data/lib/rumale/model_selection/cross_validation.rb +123 -0
- data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
- data/lib/rumale/model_selection/k_fold.rb +76 -0
- data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
- data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
- data/lib/rumale/optimizer/nadam.rb +90 -0
- data/lib/rumale/optimizer/rmsprop.rb +69 -0
- data/lib/rumale/optimizer/sgd.rb +65 -0
- data/lib/rumale/optimizer/yellow_fin.rb +144 -0
- data/lib/rumale/pairwise_metric.rb +91 -0
- data/lib/rumale/pipeline/pipeline.rb +197 -0
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
- data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
- data/lib/rumale/preprocessing/label_encoder.rb +94 -0
- data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
- data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
- data/lib/rumale/probabilistic_output.rb +112 -0
- data/lib/rumale/tree/base_decision_tree.rb +153 -0
- data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
- data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
- data/lib/rumale/tree/node.rb +70 -0
- data/lib/rumale/utils.rb +37 -0
- data/lib/rumale/validation.rb +79 -0
- data/lib/rumale/values.rb +13 -0
- data/lib/rumale/version.rb +6 -0
- data/rumale.gemspec +41 -0
- metadata +204 -0
data/bin/console
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
#!/usr/bin/env ruby
# Interactive console (IRB) with the rumale library preloaded.

# Configure load paths through Bundler first so that `require 'rumale'`
# resolves against the project's Gemfile.
require 'bundler/setup'
require 'rumale'

# You can add fixtures and/or initialization code here to make experimenting
# with your gem easier. You can also use a different console, if you like.

# (If you use this, don't forget to add pry to your Gemfile!)
# require 'pry'
# Pry.start

# Start an IRB session; __FILE__ is passed as the session's ap_path.
require 'irb'
IRB.start(__FILE__)
|
data/bin/setup
ADDED
data/lib/rumale.rb
ADDED
@@ -0,0 +1,70 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
|
5
|
+
require 'rumale/version'
|
6
|
+
require 'rumale/validation'
|
7
|
+
require 'rumale/values'
|
8
|
+
require 'rumale/utils'
|
9
|
+
require 'rumale/pairwise_metric'
|
10
|
+
require 'rumale/dataset'
|
11
|
+
require 'rumale/probabilistic_output'
|
12
|
+
require 'rumale/base/base_estimator'
|
13
|
+
require 'rumale/base/classifier'
|
14
|
+
require 'rumale/base/regressor'
|
15
|
+
require 'rumale/base/cluster_analyzer'
|
16
|
+
require 'rumale/base/transformer'
|
17
|
+
require 'rumale/base/splitter'
|
18
|
+
require 'rumale/base/evaluator'
|
19
|
+
require 'rumale/optimizer/sgd'
|
20
|
+
require 'rumale/optimizer/rmsprop'
|
21
|
+
require 'rumale/optimizer/nadam'
|
22
|
+
require 'rumale/optimizer/yellow_fin'
|
23
|
+
require 'rumale/pipeline/pipeline'
|
24
|
+
require 'rumale/kernel_approximation/rbf'
|
25
|
+
require 'rumale/linear_model/base_linear_model'
|
26
|
+
require 'rumale/linear_model/svc'
|
27
|
+
require 'rumale/linear_model/svr'
|
28
|
+
require 'rumale/linear_model/logistic_regression'
|
29
|
+
require 'rumale/linear_model/linear_regression'
|
30
|
+
require 'rumale/linear_model/ridge'
|
31
|
+
require 'rumale/linear_model/lasso'
|
32
|
+
require 'rumale/kernel_machine/kernel_svc'
|
33
|
+
require 'rumale/polynomial_model/base_factorization_machine'
|
34
|
+
require 'rumale/polynomial_model/factorization_machine_classifier'
|
35
|
+
require 'rumale/polynomial_model/factorization_machine_regressor'
|
36
|
+
require 'rumale/multiclass/one_vs_rest_classifier'
|
37
|
+
require 'rumale/nearest_neighbors/k_neighbors_classifier'
|
38
|
+
require 'rumale/nearest_neighbors/k_neighbors_regressor'
|
39
|
+
require 'rumale/naive_bayes/naive_bayes'
|
40
|
+
require 'rumale/tree/node'
|
41
|
+
require 'rumale/tree/base_decision_tree'
|
42
|
+
require 'rumale/tree/decision_tree_classifier'
|
43
|
+
require 'rumale/tree/decision_tree_regressor'
|
44
|
+
require 'rumale/ensemble/ada_boost_classifier'
|
45
|
+
require 'rumale/ensemble/ada_boost_regressor'
|
46
|
+
require 'rumale/ensemble/random_forest_classifier'
|
47
|
+
require 'rumale/ensemble/random_forest_regressor'
|
48
|
+
require 'rumale/clustering/k_means'
|
49
|
+
require 'rumale/clustering/dbscan'
|
50
|
+
require 'rumale/decomposition/pca'
|
51
|
+
require 'rumale/decomposition/nmf'
|
52
|
+
require 'rumale/preprocessing/l2_normalizer'
|
53
|
+
require 'rumale/preprocessing/min_max_scaler'
|
54
|
+
require 'rumale/preprocessing/standard_scaler'
|
55
|
+
require 'rumale/preprocessing/label_encoder'
|
56
|
+
require 'rumale/preprocessing/one_hot_encoder'
|
57
|
+
require 'rumale/model_selection/k_fold'
|
58
|
+
require 'rumale/model_selection/stratified_k_fold'
|
59
|
+
require 'rumale/model_selection/cross_validation'
|
60
|
+
require 'rumale/model_selection/grid_search_cv'
|
61
|
+
require 'rumale/evaluation_measure/accuracy'
|
62
|
+
require 'rumale/evaluation_measure/precision'
|
63
|
+
require 'rumale/evaluation_measure/recall'
|
64
|
+
require 'rumale/evaluation_measure/f_score'
|
65
|
+
require 'rumale/evaluation_measure/log_loss'
|
66
|
+
require 'rumale/evaluation_measure/r2_score'
|
67
|
+
require 'rumale/evaluation_measure/mean_squared_error'
|
68
|
+
require 'rumale/evaluation_measure/mean_absolute_error'
|
69
|
+
require 'rumale/evaluation_measure/purity'
|
70
|
+
require 'rumale/evaluation_measure/normalized_mutual_information'
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Rumale
  # Namespace for the mix-in modules shared by Rumale estimators.
  module Base
    # Mix-in that exposes an estimator's hyperparameter hash.
    module BaseEstimator
      # @!attribute [r] params
      #   Parameters of the estimator; including classes store them in +@params+.
      #   @return [Hash]
      attr_reader(:params)
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
require 'rumale/evaluation_measure/accuracy'
|
5
|
+
|
6
|
+
module Rumale
  module Base
    # Module for all classifiers in Rumale.
    #
    # Including classes must override {#fit} and {#predict}; {#score} is
    # implemented on top of {#predict}.
    module Classifier
      include Validation

      # An abstract method for fitting a model.
      #
      # Declared with a catch-all parameter list so that calling it with data on a
      # class that did not override it raises NotImplementedError (the intended
      # message) rather than an arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting labels.
      #
      # Accepts any arguments for the same reason as {#fit}; {#score} calls it with one argument.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the mean accuracy of the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Mean accuracy
      def score(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Accuracy.new
        evaluator.score(y, predict(x))
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
require 'rumale/evaluation_measure/purity'
|
5
|
+
|
6
|
+
module Rumale
  module Base
    # Module for all clustering algorithms in Rumale.
    #
    # Including classes must override {#fit_predict}; {#score} is implemented on top of it.
    module ClusterAnalyzer
      include Validation

      # An abstract method for analyzing clusters and predicting cluster indices.
      #
      # Declared with a catch-all parameter list so that calling it with data on a
      # class that did not override it (e.g. through {#score}) raises NotImplementedError
      # rather than an arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def fit_predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate purity of clustering result.
      #
      # Note that this fits the analyzer on +x+ (via {#fit_predict}) before scoring.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Purity
      def score(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Purity.new
        evaluator.score(y, fit_predict(x))
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
|
5
|
+
module Rumale
  module Base
    # Module for all evaluation measures in Rumale.
    module Evaluator
      include Validation

      # An abstract method for evaluation of model.
      #
      # Declared with a catch-all parameter list so that calling it with arguments on
      # an evaluator that did not override it raises NotImplementedError rather than
      # an arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def score(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
require 'rumale/evaluation_measure/r2_score'
|
5
|
+
|
6
|
+
module Rumale
  module Base
    # Module for all regressors in Rumale.
    #
    # Including classes must override {#fit} and {#predict}; {#score} is
    # implemented on top of {#predict}.
    module Regressor
      include Validation

      # An abstract method for fitting a model.
      #
      # Declared with a catch-all parameter list so that calling it with data on a
      # class that did not override it raises NotImplementedError rather than an
      # arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting target values.
      #
      # Accepts any arguments for the same reason as {#fit}; {#score} calls it with one argument.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the coefficient of determination for the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
      # @return [Float] Coefficient of determination
      def score(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        evaluator = Rumale::EvaluationMeasure::R2Score.new
        evaluator.score(y, predict(x))
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
|
5
|
+
module Rumale
  module Base
    # Module for all dataset-splitting strategies (used for model validation) in Rumale.
    module Splitter
      include Validation

      # Return the number of splits.
      # @return [Integer]
      attr_reader :n_splits

      # An abstract method for splitting dataset.
      #
      # Declared with a catch-all parameter list so that calling it with data on a
      # splitter that did not override it raises NotImplementedError rather than an
      # arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def split(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/validation'
|
4
|
+
|
5
|
+
module Rumale
  module Base
    # Module for all transformers in Rumale.
    module Transformer
      include Validation

      # An abstract method for fitting a model.
      #
      # Declared with a catch-all parameter list so that calling it with data on a
      # class that did not override it raises NotImplementedError rather than an
      # arity ArgumentError.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for fitting a model and transforming given data.
      #
      # Accepts any arguments for the same reason as {#fit}.
      #
      # @raise [NotImplementedError] Always; including classes have to override this method.
      def fit_transform(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
@@ -0,0 +1,125 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/base_estimator'
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
|
7
|
+
module Rumale
  module Clustering
    # DBSCAN is a class that implements DBSCAN cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
    class DBSCAN
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the core sample indices, i.e. the samples that have at least
      # min_samples samples (themselves included) within their eps-neighborhood.
      # @return [Numo::Int32] (shape: [n_core_samples])
      attr_reader :core_sample_ids

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with DBSCAN method.
      #
      # @param eps [Float] The radius of neighborhood.
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      def initialize(eps: 0.5, min_samples: 5)
        check_params_float(eps: eps)
        check_params_integer(min_samples: min_samples)
        # Same sign validation that KMeans applies to its numeric hyperparameters.
        check_params_positive(eps: eps, min_samples: min_samples)
        @params = {}
        @params[:eps] = eps
        @params[:min_samples] = min_samples
        @core_sample_ids = nil
        @labels = nil
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> DBSCAN
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [DBSCAN] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        partial_fit(x)
        self
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        partial_fit(x)
        labels
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          core_sample_ids: @core_sample_ids,
          labels: @labels }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @core_sample_ids = obj[:core_sample_ids]
        @labels = obj[:labels]
        nil
      end

      private

      # Run DBSCAN over every sample, filling @labels and @core_sample_ids.
      def partial_fit(x)
        cluster_id = 0
        n_samples = x.shape[0]
        @core_sample_ids = []
        # Label -2 marks a sample that has not been visited yet; -1 marks noise.
        @labels = Numo::Int32.zeros(n_samples) - 2
        n_samples.times do |q|
          next if @labels[q] >= -1

          cluster_id += 1 if expand_cluster(x, q, cluster_id)
        end
        # Deduplicate and sort so the result does not depend on the visiting order.
        @core_sample_ids = Numo::Int32[*@core_sample_ids.uniq.sort]
        nil
      end

      # Try to grow a new cluster from +query_id+.
      # Returns false (marking the query as noise) when it is not a core point, true otherwise.
      def expand_cluster(x, query_id, cluster_id)
        target_ids = region_query(x[query_id, true], x)
        if target_ids.size < @params[:min_samples]
          @labels[query_id] = -1
          return false
        end

        # The query has a dense neighborhood: it is a core point seeding this cluster.
        @labels[target_ids] = cluster_id
        @core_sample_ids.push(query_id)
        target_ids.delete(query_id)
        while (m = target_ids.shift)
          neighbor_ids = region_query(x[m, true], x)
          # m is a border point: it keeps its cluster label but does not expand the cluster.
          next if neighbor_ids.size < @params[:min_samples]

          @core_sample_ids.push(m)
          neighbor_ids.each do |n|
            target_ids.push(n) if @labels[n] < -1
            @labels[n] = cluster_id if @labels[n] <= -1
          end
        end
        true
      end

      # Indices of the samples in +targets+ whose Euclidean distance to +query+ is
      # strictly less than eps (this includes the query sample itself when eps > 0).
      def region_query(query, targets)
        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
        distance_arr.lt(@params[:eps]).where.to_a
      end
    end
  end
end
|
@@ -0,0 +1,138 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/base_estimator'
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
|
7
|
+
module Rumale
  # This module consists of classes that implement cluster analysis methods.
  module Clustering
    # KMeans is a class that implements K-Means cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMeans
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   When nil, the value returned by Kernel#srand (the previous global seed) is used.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
        check_params_float(tol: tol)
        check_params_string(init: init)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
        @params = {}
        @params[:n_clusters] = n_clusters
        @params[:init] = init == 'random' ? 'random' : 'k-means++'
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @cluster_centers = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMeans
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [KMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        init_cluster_centers(x)
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = cluster_labels.eq(n)
            # A cluster with no assigned samples keeps its previous centroid.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Mean Euclidean displacement of the centroids in this iteration.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        check_sample_array(x)
        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        fit(x)
        predict(x)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          cluster_centers: @cluster_centers,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cluster_centers = obj[:cluster_centers]
        @rng = obj[:rng]
        nil
      end

      private

      # Index of the nearest centroid for every sample.
      def assign_cluster(x)
        distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index(axis: 1) yields flat indices into distance_matrix; subtracting each
        # row's starting offset turns them into column (centroid) indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Choose initial centroids: a random subset of samples, refined by k-means++ seeding
      # unless init is 'random'.
      def init_cluster_centers(x)
        # random initialize
        n_samples = x.shape[0]
        rand_id = [*0...n_samples].sample(@params[:n_clusters], random: @rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'

        # k-means++ initialize: each further centroid is drawn with probability
        # proportional to the squared distance to its nearest already-chosen centroid.
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          selected_id = sample_weighted_index(min_distances**2)
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end

      # Draw one index with probability proportional to +weights+ (non-negative).
      def sample_weighted_index(weights)
        total = weights.sum
        # Every sample coincides with a chosen centroid: the distribution is 0/0 (NaN),
        # so fall back to a uniform draw instead of producing a nil index.
        return @rng.rand(weights.size) unless total.positive?

        cum_probs = (weights / total).cumsum
        # Rounding can leave the last cumulative probability slightly below 1.0; if the
        # draw lands above it, take the last sample that actually has non-zero weight.
        cum_probs.gt(@rng.rand).where.to_a.first || weights.gt(0).where.to_a.last
      end
    end
  end
end
|