rumale 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.coveralls.yml +1 -0
- data/.gitignore +20 -0
- data/.rspec +3 -0
- data/.rubocop.yml +47 -0
- data/.rubocop_todo.yml +58 -0
- data/.travis.yml +13 -0
- data/CHANGELOG.md +2 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +175 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/rumale.rb +70 -0
- data/lib/rumale/base/base_estimator.rb +13 -0
- data/lib/rumale/base/classifier.rb +36 -0
- data/lib/rumale/base/cluster_analyzer.rb +31 -0
- data/lib/rumale/base/evaluator.rb +17 -0
- data/lib/rumale/base/regressor.rb +36 -0
- data/lib/rumale/base/splitter.rb +21 -0
- data/lib/rumale/base/transformer.rb +22 -0
- data/lib/rumale/clustering/dbscan.rb +125 -0
- data/lib/rumale/clustering/k_means.rb +138 -0
- data/lib/rumale/dataset.rb +110 -0
- data/lib/rumale/decomposition/nmf.rb +141 -0
- data/lib/rumale/decomposition/pca.rb +148 -0
- data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
- data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
- data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
- data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
- data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
- data/lib/rumale/evaluation_measure/f_score.rb +50 -0
- data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
- data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
- data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
- data/lib/rumale/evaluation_measure/precision.rb +50 -0
- data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
- data/lib/rumale/evaluation_measure/purity.rb +40 -0
- data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
- data/lib/rumale/evaluation_measure/recall.rb +50 -0
- data/lib/rumale/kernel_approximation/rbf.rb +121 -0
- data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
- data/lib/rumale/linear_model/base_linear_model.rb +89 -0
- data/lib/rumale/linear_model/lasso.rb +136 -0
- data/lib/rumale/linear_model/linear_regression.rb +110 -0
- data/lib/rumale/linear_model/logistic_regression.rb +159 -0
- data/lib/rumale/linear_model/ridge.rb +110 -0
- data/lib/rumale/linear_model/svc.rb +183 -0
- data/lib/rumale/linear_model/svr.rb +122 -0
- data/lib/rumale/model_selection/cross_validation.rb +123 -0
- data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
- data/lib/rumale/model_selection/k_fold.rb +76 -0
- data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
- data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
- data/lib/rumale/optimizer/nadam.rb +90 -0
- data/lib/rumale/optimizer/rmsprop.rb +69 -0
- data/lib/rumale/optimizer/sgd.rb +65 -0
- data/lib/rumale/optimizer/yellow_fin.rb +144 -0
- data/lib/rumale/pairwise_metric.rb +91 -0
- data/lib/rumale/pipeline/pipeline.rb +197 -0
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
- data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
- data/lib/rumale/preprocessing/label_encoder.rb +94 -0
- data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
- data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
- data/lib/rumale/probabilistic_output.rb +112 -0
- data/lib/rumale/tree/base_decision_tree.rb +153 -0
- data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
- data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
- data/lib/rumale/tree/node.rb +70 -0
- data/lib/rumale/utils.rb +37 -0
- data/lib/rumale/validation.rb +79 -0
- data/lib/rumale/values.rb +13 -0
- data/lib/rumale/version.rb +6 -0
- data/rumale.gemspec +41 -0
- metadata +204 -0
data/bin/console
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env ruby
+
+require 'bundler/setup'
+require 'rumale'
+
+# You can add fixtures and/or initialization code here to make experimenting
+# with your gem easier. You can also use a different console, if you like.
+
+# (If you use this, don't forget to add pry to your Gemfile!)
+# require 'pry'
+# Pry.start
+
+require 'irb'
+IRB.start(__FILE__)
data/bin/setup
ADDED
data/lib/rumale.rb
ADDED
@@ -0,0 +1,70 @@
+# frozen_string_literal: true
+
+require 'numo/narray'
+
+require 'rumale/version'
+require 'rumale/validation'
+require 'rumale/values'
+require 'rumale/utils'
+require 'rumale/pairwise_metric'
+require 'rumale/dataset'
+require 'rumale/probabilistic_output'
+require 'rumale/base/base_estimator'
+require 'rumale/base/classifier'
+require 'rumale/base/regressor'
+require 'rumale/base/cluster_analyzer'
+require 'rumale/base/transformer'
+require 'rumale/base/splitter'
+require 'rumale/base/evaluator'
+require 'rumale/optimizer/sgd'
+require 'rumale/optimizer/rmsprop'
+require 'rumale/optimizer/nadam'
+require 'rumale/optimizer/yellow_fin'
+require 'rumale/pipeline/pipeline'
+require 'rumale/kernel_approximation/rbf'
+require 'rumale/linear_model/base_linear_model'
+require 'rumale/linear_model/svc'
+require 'rumale/linear_model/svr'
+require 'rumale/linear_model/logistic_regression'
+require 'rumale/linear_model/linear_regression'
+require 'rumale/linear_model/ridge'
+require 'rumale/linear_model/lasso'
+require 'rumale/kernel_machine/kernel_svc'
+require 'rumale/polynomial_model/base_factorization_machine'
+require 'rumale/polynomial_model/factorization_machine_classifier'
+require 'rumale/polynomial_model/factorization_machine_regressor'
+require 'rumale/multiclass/one_vs_rest_classifier'
+require 'rumale/nearest_neighbors/k_neighbors_classifier'
+require 'rumale/nearest_neighbors/k_neighbors_regressor'
+require 'rumale/naive_bayes/naive_bayes'
+require 'rumale/tree/node'
+require 'rumale/tree/base_decision_tree'
+require 'rumale/tree/decision_tree_classifier'
+require 'rumale/tree/decision_tree_regressor'
+require 'rumale/ensemble/ada_boost_classifier'
+require 'rumale/ensemble/ada_boost_regressor'
+require 'rumale/ensemble/random_forest_classifier'
+require 'rumale/ensemble/random_forest_regressor'
+require 'rumale/clustering/k_means'
+require 'rumale/clustering/dbscan'
+require 'rumale/decomposition/pca'
+require 'rumale/decomposition/nmf'
+require 'rumale/preprocessing/l2_normalizer'
+require 'rumale/preprocessing/min_max_scaler'
+require 'rumale/preprocessing/standard_scaler'
+require 'rumale/preprocessing/label_encoder'
+require 'rumale/preprocessing/one_hot_encoder'
+require 'rumale/model_selection/k_fold'
+require 'rumale/model_selection/stratified_k_fold'
+require 'rumale/model_selection/cross_validation'
+require 'rumale/model_selection/grid_search_cv'
+require 'rumale/evaluation_measure/accuracy'
+require 'rumale/evaluation_measure/precision'
+require 'rumale/evaluation_measure/recall'
+require 'rumale/evaluation_measure/f_score'
+require 'rumale/evaluation_measure/log_loss'
+require 'rumale/evaluation_measure/r2_score'
+require 'rumale/evaluation_measure/mean_squared_error'
+require 'rumale/evaluation_measure/mean_absolute_error'
+require 'rumale/evaluation_measure/purity'
+require 'rumale/evaluation_measure/normalized_mutual_information'
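Requiring 'rumale' pulls in every module listed above, so a single require is enough to reach any estimator in the gem. A minimal sketch of what that looks like in practice; the toy samples and labels below are invented for illustration only:

  require 'rumale'

  # Hypothetical toy data: four samples with two features each, and binary labels.
  samples = Numo::DFloat[[0.0, 0.0], [0.1, 0.2], [5.0, 5.1], [5.2, 4.9]]
  labels  = Numo::Int32[0, 0, 1, 1]

  # Any class loaded by rumale.rb is available after the single require.
  svc = Rumale::LinearModel::SVC.new(random_seed: 1)
  svc.fit(samples, labels)
  p svc.predict(samples).to_a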
data/lib/rumale/base/base_estimator.rb
ADDED
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module Rumale
+  # This module consists of basic mix-in classes.
+  module Base
+    # Base module for all estimators in Rumale.
+    module BaseEstimator
+      # Return parameters about an estimator.
+      # @return [Hash]
+      attr_reader :params
+    end
+  end
+end
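BaseEstimator contributes only the params reader, so the hyperparameters of any estimator can be inspected as a Hash after construction. A small sketch; the choice of KMeans here is arbitrary:

  require 'rumale'

  # :params is the reader provided by Base::BaseEstimator; it exposes the constructor arguments.
  kmeans = Rumale::Clustering::KMeans.new(n_clusters: 3, max_iter: 100, random_seed: 1)
  p kmeans.params
  # => {:n_clusters=>3, :init=>"k-means++", :max_iter=>100, :tol=>0.0001, :random_seed=>1}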
data/lib/rumale/base/classifier.rb
ADDED
@@ -0,0 +1,36 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+require 'rumale/evaluation_measure/accuracy'
+
+module Rumale
+  module Base
+    # Module for all classifiers in Rumale.
+    module Classifier
+      include Validation
+
+      # An abstract method for fitting a model.
+      def fit
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # An abstract method for predicting labels.
+      def predict
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # Calculate the mean accuracy of the given testing data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
+      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
+      # @return [Float] Mean accuracy
+      def score(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        evaluator = Rumale::EvaluationMeasure::Accuracy.new
+        evaluator.score(y, predict(x))
+      end
+    end
+  end
+end
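Because score delegates to EvaluationMeasure::Accuracy, every classifier that mixes in Base::Classifier reports mean accuracy the same way. A hedged sketch using a decision tree; the toy data is made up for illustration:

  require 'rumale'

  samples = Numo::DFloat[[0.0, 0.0], [0.2, 0.1], [1.0, 1.1], [0.9, 1.0]]
  labels  = Numo::Int32[0, 0, 1, 1]

  tree = Rumale::Tree::DecisionTreeClassifier.new(random_seed: 1)
  tree.fit(samples, labels)

  # score is inherited from Base::Classifier and returns mean accuracy.
  puts tree.score(samples, labels) # 1.0 on this trivially separable toy set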
data/lib/rumale/base/cluster_analyzer.rb
ADDED
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+require 'rumale/evaluation_measure/purity'
+
+module Rumale
+  module Base
+    # Module for all clustering algorithms in Rumale.
+    module ClusterAnalyzer
+      include Validation
+
+      # An abstract method for analyzing clusters and predicting cluster indices.
+      def fit_predict
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # Calculate purity of clustering result.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
+      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
+      # @return [Float] Purity
+      def score(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        evaluator = Rumale::EvaluationMeasure::Purity.new
+        evaluator.score(y, fit_predict(x))
+      end
+    end
+  end
+end
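Here score runs fit_predict and measures purity against the given labels, so any clustering class mixing in Base::ClusterAnalyzer can be scored the same way. A small sketch with invented data and ground-truth labels:

  require 'rumale'

  # Hypothetical data: two well-separated groups with known labels.
  samples     = Numo::DFloat[[0.0, 0.0], [0.1, 0.1], [3.0, 3.0], [3.1, 2.9]]
  true_labels = Numo::Int32[0, 0, 1, 1]

  analyzer = Rumale::Clustering::KMeans.new(n_clusters: 2, random_seed: 1)
  # score comes from Base::ClusterAnalyzer: it clusters x and reports purity against y.
  puts analyzer.score(samples, true_labels)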
data/lib/rumale/base/evaluator.rb
ADDED
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+
+module Rumale
+  module Base
+    # Module for all evaluation measures in Rumale.
+    module Evaluator
+      include Validation
+
+      # An abstract method for evaluation of model.
+      def score
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+    end
+  end
+end
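Concrete evaluators implement score(y_true, y_pred) on top of this mix-in, so they can also be used standalone. A tiny sketch with invented label arrays:

  require 'rumale'

  ground_truth = Numo::Int32[0, 1, 1, 0]
  predicted    = Numo::Int32[0, 1, 0, 0]

  # Accuracy is the simplest evaluator; 3 of 4 predictions match.
  puts Rumale::EvaluationMeasure::Accuracy.new.score(ground_truth, predicted) # 0.75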
data/lib/rumale/base/regressor.rb
ADDED
@@ -0,0 +1,36 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+require 'rumale/evaluation_measure/r2_score'
+
+module Rumale
+  module Base
+    # Module for all regressors in Rumale.
+    module Regressor
+      include Validation
+
+      # An abstract method for fitting a model.
+      def fit
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # An abstract method for predicting labels.
+      def predict
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # Calculate the coefficient of determination for the given testing data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
+      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
+      # @return [Float] Coefficient of determination
+      def score(x, y)
+        check_sample_array(x)
+        check_tvalue_array(y)
+        check_sample_tvalue_size(x, y)
+        evaluator = Rumale::EvaluationMeasure::R2Score.new
+        evaluator.score(y, predict(x))
+      end
+    end
+  end
+end
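The regressor counterpart reports the coefficient of determination via EvaluationMeasure::R2Score. A hedged sketch with Ridge; the data is invented and the exact score depends on the SGD-based fit:

  require 'rumale'

  # Toy regression data (hypothetical): y is roughly 2 * x.
  x = Numo::DFloat[[1.0], [2.0], [3.0], [4.0]]
  y = Numo::DFloat[2.1, 3.9, 6.2, 7.8]

  reg = Rumale::LinearModel::Ridge.new(max_iter: 500, random_seed: 1)
  reg.fit(x, y)

  # score comes from Base::Regressor and reports R^2 on the given data.
  puts reg.score(x, y)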
data/lib/rumale/base/splitter.rb
ADDED
@@ -0,0 +1,21 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+
+module Rumale
+  module Base
+    # Module for all validation methods in Rumale.
+    module Splitter
+      include Validation
+
+      # Return the number of splits.
+      # @return [Integer]
+      attr_reader :n_splits
+
+      # An abstract method for splitting dataset.
+      def split
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+    end
+  end
+end
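Splitters expose split and the n_splits reader, so cross-validation folds can be generated uniformly. A small sketch with KFold on random placeholder data:

  require 'rumale'

  samples = Numo::DFloat.new(10, 2).rand
  labels  = Numo::Int32[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

  # KFold mixes in Base::Splitter; split yields [training_ids, testing_ids] pairs.
  kf = Rumale::ModelSelection::KFold.new(n_splits: 5, shuffle: true, random_seed: 1)
  kf.split(samples, labels).each do |train_ids, test_ids|
    p [train_ids.size, test_ids.size] # [8, 2] for each of the five folds
  end
  p kf.n_splits # reader provided by Base::Splitter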
data/lib/rumale/base/transformer.rb
ADDED
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+require 'rumale/validation'
+
+module Rumale
+  module Base
+    # Module for all transfomers in Rumale.
+    module Transformer
+      include Validation
+
+      # An abstract method for fitting a model.
+      def fit
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+
+      # An abstract method for fitting a model and transforming given data.
+      def fit_transform
+        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+      end
+    end
+  end
+end
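Transformers share the fit / fit_transform interface, so preprocessing classes are interchangeable in pipelines. A small sketch with MinMaxScaler; the data is invented:

  require 'rumale'

  samples = Numo::DFloat[[1.0, -1.0], [2.0, 0.0], [3.0, 1.0]]

  scaler = Rumale::Preprocessing::MinMaxScaler.new(feature_range: [0.0, 1.0])
  scaled = scaler.fit_transform(samples) # each column rescaled into [0, 1]
  p scaled.to_a

  # A fitted scaler can be reused on new data via transform.
  p scaler.transform(Numo::DFloat[[2.5, 0.5]]).to_a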
data/lib/rumale/clustering/dbscan.rb
ADDED
@@ -0,0 +1,125 @@
+# frozen_string_literal: true
+
+require 'rumale/base/base_estimator'
+require 'rumale/base/cluster_analyzer'
+require 'rumale/pairwise_metric'
+
+module Rumale
+  module Clustering
+    # DBSCAN is a class that implements DBSCAN cluster analysis.
+    # The current implementation uses the Euclidean distance for analyzing the clusters.
+    #
+    # @example
+    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
+    #   cluster_labels = analyzer.fit_predict(samples)
+    #
+    # *Reference*
+    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 266--231, 1996.
+    class DBSCAN
+      include Base::BaseEstimator
+      include Base::ClusterAnalyzer
+
+      # Return the core sample indices.
+      # @return [Numo::Int32] (shape: [n_core_samples])
+      attr_reader :core_sample_ids
+
+      # Return the cluster labels. The negative cluster label indicates that the point is noise.
+      # @return [Numo::Int32] (shape: [n_samples])
+      attr_reader :labels
+
+      # Create a new cluster analyzer with DBSCAN method.
+      #
+      # @param eps [Float] The radius of neighborhood.
+      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
+      def initialize(eps: 0.5, min_samples: 5)
+        check_params_float(eps: eps)
+        check_params_integer(min_samples: min_samples)
+        @params = {}
+        @params[:eps] = eps
+        @params[:min_samples] = min_samples
+        @core_sample_ids = nil
+        @labels = nil
+      end
+
+      # Analysis clusters with given training data.
+      #
+      # @overload fit(x) -> DBSCAN
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+      # @return [DBSCAN] The learned cluster analyzer itself.
+      def fit(x, _y = nil)
+        check_sample_array(x)
+        partial_fit(x)
+        self
+      end
+
+      # Analysis clusters and assign samples to clusters.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+      def fit_predict(x)
+        check_sample_array(x)
+        partial_fit(x)
+        labels
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data.
+      def marshal_dump
+        { params: @params,
+          core_sample_ids: @core_sample_ids,
+          labels: @labels }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @core_sample_ids = obj[:core_sample_ids]
+        @labels = obj[:labels]
+        nil
+      end
+
+      private
+
+      def partial_fit(x)
+        cluster_id = 0
+        n_samples = x.shape[0]
+        @core_sample_ids = []
+        @labels = Numo::Int32.zeros(n_samples) - 2
+        n_samples.times do |q|
+          next if @labels[q] >= -1
+          cluster_id += 1 if expand_cluster(x, q, cluster_id)
+        end
+        @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
+        nil
+      end
+
+      def expand_cluster(x, query_id, cluster_id)
+        target_ids = region_query(x[query_id, true], x)
+        if target_ids.size < @params[:min_samples]
+          @labels[query_id] = -1
+          false
+        else
+          @labels[target_ids] = cluster_id
+          @core_sample_ids.push(target_ids.dup)
+          target_ids.delete(query_id)
+          while (m = target_ids.shift)
+            neighbor_ids = region_query(x[m, true], x)
+            next if neighbor_ids.size < @params[:min_samples]
+            neighbor_ids.each do |n|
+              target_ids.push(n) if @labels[n] < -1
+              @labels[n] = cluster_id if @labels[n] <= -1
+            end
+          end
+          true
+        end
+      end
+
+      def region_query(query, targets)
+        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
+        distance_arr.lt(@params[:eps]).where.to_a
+      end
+    end
+  end
+end
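To make the DBSCAN hunk above concrete, here is a hedged usage sketch: two dense blobs plus a far-away outlier. The data is invented and the exact labels depend on eps and min_samples:

  require 'rumale'

  # Two tight groups of points and one distant outlier (hypothetical data).
  samples = Numo::DFloat[
    [0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
    [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],
    [20.0, 20.0]
  ]

  analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 2)
  labels = analyzer.fit_predict(samples)

  p labels.to_a                     # e.g. [0, 0, 0, 1, 1, 1, -1]; negative labels mark noise
  p analyzer.core_sample_ids.to_a   # indices of the core samples found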
data/lib/rumale/clustering/k_means.rb
ADDED
@@ -0,0 +1,138 @@
+# frozen_string_literal: true
+
+require 'rumale/base/base_estimator'
+require 'rumale/base/cluster_analyzer'
+require 'rumale/pairwise_metric'
+
+module Rumale
+  # This module consists of classes that implement cluster analysis methods.
+  module Clustering
+    # KMeans is a class that implements K-Means cluster analysis.
+    # The current implementation uses the Euclidean distance for analyzing the clusters.
+    #
+    # @example
+    #   analyzer = Rumale::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
+    #   cluster_labels = analyzer.fit_predict(samples)
+    #
+    # *Reference*
+    # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
+    class KMeans
+      include Base::BaseEstimator
+      include Base::ClusterAnalyzer
+
+      # Return the centroids.
+      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
+      attr_reader :cluster_centers
+
+      # Return the random generator.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new cluster analyzer with K-Means method.
+      #
+      # @param n_clusters [Integer] The number of clusters.
+      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
+      # @param max_iter [Integer] The maximum number of iterations.
+      # @param tol [Float] The tolerance of termination criterion.
+      # @param random_seed [Integer] The seed value using to initialize the random generator.
+      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
+        check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
+        check_params_float(tol: tol)
+        check_params_string(init: init)
+        check_params_type_or_nil(Integer, random_seed: random_seed)
+        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
+        @params = {}
+        @params[:n_clusters] = n_clusters
+        @params[:init] = init == 'random' ? 'random' : 'k-means++'
+        @params[:max_iter] = max_iter
+        @params[:tol] = tol
+        @params[:random_seed] = random_seed
+        @params[:random_seed] ||= srand
+        @cluster_centers = nil
+        @rng = Random.new(@params[:random_seed])
+      end
+
+      # Analysis clusters with given training data.
+      #
+      # @overload fit(x) -> KMeans
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+      # @return [KMeans] The learned cluster analyzer itself.
+      def fit(x, _y = nil)
+        check_sample_array(x)
+        init_cluster_centers(x)
+        @params[:max_iter].times do |_t|
+          cluster_labels = assign_cluster(x)
+          old_centers = @cluster_centers.dup
+          @params[:n_clusters].times do |n|
+            assigned_bits = cluster_labels.eq(n)
+            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
+          end
+          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
+          break if error <= @params[:tol]
+        end
+        self
+      end
+
+      # Predict cluster labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+      def predict(x)
+        check_sample_array(x)
+        assign_cluster(x)
+      end
+
+      # Analysis clusters and assign samples to clusters.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+      def fit_predict(x)
+        check_sample_array(x)
+        fit(x)
+        predict(x)
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data.
+      def marshal_dump
+        { params: @params,
+          cluster_centers: @cluster_centers,
+          rng: @rng }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @cluster_centers = obj[:cluster_centers]
+        @rng = obj[:rng]
+        nil
+      end
+
+      private
+
+      def assign_cluster(x)
+        distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
+        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
+      end
+
+      def init_cluster_centers(x)
+        # random initialize
+        n_samples = x.shape[0]
+        rand_id = [*0...n_samples].sample(@params[:n_clusters], random: @rng)
+        @cluster_centers = x[rand_id, true].dup
+        return unless @params[:init] == 'k-means++'
+        # k-means++ initialize
+        (1...@params[:n_clusters]).each do |n|
+          distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
+          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
+          probs = min_distances**2 / (min_distances**2).sum
+          cum_probs = probs.cumsum
+          selected_id = cum_probs.gt(@rng.rand).where.to_a.first
+          @cluster_centers[n, true] = x[selected_id, true].dup
+        end
+      end
+    end
+  end
+end
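A matching sketch for the KMeans hunk: with k-means++ seeding and a fixed random_seed the run is reproducible, though the particular cluster numbering is arbitrary. The data below is invented:

  require 'rumale'

  samples = Numo::DFloat[
    [0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
    [4.0, 4.0], [4.2, 3.9], [3.9, 4.1]
  ]

  analyzer = Rumale::Clustering::KMeans.new(n_clusters: 2, max_iter: 50, random_seed: 1)
  cluster_labels = analyzer.fit_predict(samples)

  p cluster_labels.to_a          # two groups, e.g. [0, 0, 0, 1, 1, 1] (numbering may differ)
  p analyzer.cluster_centers.to_a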