rumale 0.23.3 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE.txt +5 -1
- data/README.md +3 -288
- data/lib/rumale/version.rb +1 -1
- data/lib/rumale.rb +20 -131
- metadata +252 -150
- data/CHANGELOG.md +0 -643
- data/CODE_OF_CONDUCT.md +0 -74
- data/ext/rumale/extconf.rb +0 -37
- data/ext/rumale/rumaleext.c +0 -545
- data/ext/rumale/rumaleext.h +0 -12
- data/lib/rumale/base/base_estimator.rb +0 -49
- data/lib/rumale/base/classifier.rb +0 -36
- data/lib/rumale/base/cluster_analyzer.rb +0 -31
- data/lib/rumale/base/evaluator.rb +0 -17
- data/lib/rumale/base/regressor.rb +0 -36
- data/lib/rumale/base/splitter.rb +0 -21
- data/lib/rumale/base/transformer.rb +0 -22
- data/lib/rumale/clustering/dbscan.rb +0 -123
- data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
- data/lib/rumale/clustering/hdbscan.rb +0 -291
- data/lib/rumale/clustering/k_means.rb +0 -122
- data/lib/rumale/clustering/k_medoids.rb +0 -141
- data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
- data/lib/rumale/clustering/power_iteration.rb +0 -127
- data/lib/rumale/clustering/single_linkage.rb +0 -203
- data/lib/rumale/clustering/snn.rb +0 -76
- data/lib/rumale/clustering/spectral_clustering.rb +0 -115
- data/lib/rumale/dataset.rb +0 -246
- data/lib/rumale/decomposition/factor_analysis.rb +0 -150
- data/lib/rumale/decomposition/fast_ica.rb +0 -188
- data/lib/rumale/decomposition/nmf.rb +0 -124
- data/lib/rumale/decomposition/pca.rb +0 -159
- data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
- data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
- data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
- data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
- data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
- data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
- data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
- data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
- data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
- data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
- data/lib/rumale/ensemble/voting_classifier.rb +0 -126
- data/lib/rumale/ensemble/voting_regressor.rb +0 -82
- data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
- data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
- data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
- data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
- data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
- data/lib/rumale/evaluation_measure/f_score.rb +0 -50
- data/lib/rumale/evaluation_measure/function.rb +0 -147
- data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
- data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
- data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
- data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
- data/lib/rumale/evaluation_measure/precision.rb +0 -50
- data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
- data/lib/rumale/evaluation_measure/purity.rb +0 -40
- data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
- data/lib/rumale/evaluation_measure/recall.rb +0 -50
- data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
- data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
- data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
- data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
- data/lib/rumale/kernel_approximation/rbf.rb +0 -102
- data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
- data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
- data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
- data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
- data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
- data/lib/rumale/linear_model/base_sgd.rb +0 -285
- data/lib/rumale/linear_model/elastic_net.rb +0 -119
- data/lib/rumale/linear_model/lasso.rb +0 -115
- data/lib/rumale/linear_model/linear_regression.rb +0 -201
- data/lib/rumale/linear_model/logistic_regression.rb +0 -275
- data/lib/rumale/linear_model/nnls.rb +0 -137
- data/lib/rumale/linear_model/ridge.rb +0 -209
- data/lib/rumale/linear_model/svc.rb +0 -213
- data/lib/rumale/linear_model/svr.rb +0 -132
- data/lib/rumale/manifold/mds.rb +0 -155
- data/lib/rumale/manifold/tsne.rb +0 -222
- data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
- data/lib/rumale/metric_learning/mlkr.rb +0 -161
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
- data/lib/rumale/model_selection/cross_validation.rb +0 -125
- data/lib/rumale/model_selection/function.rb +0 -42
- data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
- data/lib/rumale/model_selection/group_k_fold.rb +0 -93
- data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
- data/lib/rumale/model_selection/k_fold.rb +0 -81
- data/lib/rumale/model_selection/shuffle_split.rb +0 -90
- data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
- data/lib/rumale/model_selection/time_series_split.rb +0 -91
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
- data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
- data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
- data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
- data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
- data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
- data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
- data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
- data/lib/rumale/neural_network/adam.rb +0 -56
- data/lib/rumale/neural_network/base_mlp.rb +0 -248
- data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
- data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
- data/lib/rumale/pairwise_metric.rb +0 -152
- data/lib/rumale/pipeline/feature_union.rb +0 -69
- data/lib/rumale/pipeline/pipeline.rb +0 -175
- data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
- data/lib/rumale/preprocessing/binarizer.rb +0 -60
- data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
- data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
- data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
- data/lib/rumale/preprocessing/label_encoder.rb +0 -79
- data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
- data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
- data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
- data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
- data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
- data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
- data/lib/rumale/probabilistic_output.rb +0 -114
- data/lib/rumale/tree/base_decision_tree.rb +0 -150
- data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
- data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
- data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
- data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
- data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
- data/lib/rumale/tree/node.rb +0 -39
- data/lib/rumale/utils.rb +0 -42
- data/lib/rumale/validation.rb +0 -128
- data/lib/rumale/values.rb +0 -13
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Rumale
  # This module consists of basic mix-in classes.
  module Base
    # Base module for all estimators in Rumale.
    module BaseEstimator
      # Return parameters about an estimator.
      # @return [Hash]
      attr_reader :params

      private

      # Check whether Numo::Linalg is loaded and is recent enough for Rumale.
      #
      # @param warning [Boolean] Whether to emit a warning when Numo::Linalg cannot be used.
      # @return [Boolean] true if Numo::Linalg version 0.1.4 or later is loaded, false otherwise.
      def enable_linalg?(warning: true)
        if defined?(Numo::Linalg).nil?
          warn('If you want to use features that depend on Numo::Linalg, you should install and load Numo::Linalg in advance.') if warning
          return false
        end
        # Compare as versions rather than strings: lexicographically '0.1.10' < '0.1.4'.
        if Gem::Version.new(Numo::Linalg::VERSION) < Gem::Version.new('0.1.4')
          if warning
            warn('The loaded Numo::Linalg does not implement the methods required by Rumale. Please load Numo::Linalg version 0.1.4 or later.')
          end
          return false
        end
        true
      end

      # Check whether parallel execution was requested (params[:n_jobs] is set) and the Parallel gem is loaded.
      #
      # @return [Boolean]
      def enable_parallel?
        return false if @params[:n_jobs].nil?

        if defined?(Parallel).nil?
          warn('If you want to use parallel option, you should install and load Parallel in advance.')
          return false
        end
        true
      end

      # Number of worker processes to use.
      # A non-positive n_jobs means "use all processors".
      #
      # @return [Integer] 1 when parallel execution is disabled.
      def n_processes
        return 1 unless enable_parallel?

        @params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
      end

      # Map the indices 0...n_outputs through the given block in parallel processes.
      #
      # @param n_outputs [Integer] The number of indices to map over.
      # @return [Array] The block results, in index order.
      def parallel_map(n_outputs, &block)
        Parallel.map((0...n_outputs).to_a, in_processes: n_processes, &block)
      end
    end
  end
end
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
require 'rumale/evaluation_measure/accuracy'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module Base
    # Mix-in shared by every classifier in Rumale.
    # Including classes must provide concrete #fit and #predict implementations.
    module Classifier
      include Validation

      # Abstract: fit the model. Raises NotImplementedError unless overridden.
      def fit
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Abstract: predict class labels. Raises NotImplementedError unless overridden.
      def predict
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Evaluate the classifier by mean accuracy on the given test set.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Mean accuracy
      def score(x, y)
        samples = check_convert_sample_array(x)
        labels = check_convert_label_array(y)
        check_sample_label_size(samples, labels)
        Rumale::EvaluationMeasure::Accuracy.new.score(labels, predict(samples))
      end
    end
  end
end
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
require 'rumale/evaluation_measure/purity'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module Base
    # Mix-in shared by every clustering algorithm in Rumale.
    # Including classes must provide a concrete #fit_predict implementation.
    module ClusterAnalyzer
      include Validation

      # Abstract: analyze clusters and return a cluster index per sample.
      # Raises NotImplementedError unless overridden.
      def fit_predict
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Evaluate the clustering by its purity against the given true labels.
      # Note that this re-runs #fit_predict on x.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Purity
      def score(x, y)
        samples = check_convert_sample_array(x)
        labels = check_convert_label_array(y)
        check_sample_label_size(samples, labels)
        Rumale::EvaluationMeasure::Purity.new.score(labels, fit_predict(samples))
      end
    end
  end
end
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module Base
    # Mix-in shared by every evaluation measure in Rumale.
    # Including classes must provide a concrete #score implementation.
    module Evaluator
      include Validation

      # Abstract: compute the evaluation score.
      # Raises NotImplementedError unless overridden.
      def score
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
require 'rumale/evaluation_measure/r2_score'
|
|
5
|
-
|
|
6
|
-
module Rumale
  module Base
    # Mix-in shared by every regressor in Rumale.
    # Including classes must provide concrete #fit and #predict implementations.
    module Regressor
      include Validation

      # Abstract: fit the model. Raises NotImplementedError unless overridden.
      def fit
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Abstract: predict target values. Raises NotImplementedError unless overridden.
      def predict
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Evaluate the regressor by the coefficient of determination (R^2) on the given test set.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
      # @return [Float] Coefficient of determination
      def score(x, y)
        samples = check_convert_sample_array(x)
        targets = check_convert_tvalue_array(y)
        check_sample_tvalue_size(samples, targets)
        Rumale::EvaluationMeasure::R2Score.new.score(targets, predict(samples))
      end
    end
  end
end
|
data/lib/rumale/base/splitter.rb
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module Base
    # Mix-in shared by every dataset-splitting (validation) method in Rumale.
    # Including classes must provide a concrete #split implementation.
    module Splitter
      include Validation

      # The number of splits produced by this splitter.
      # @return [Integer]
      attr_reader :n_splits

      # Abstract: split a dataset. Raises NotImplementedError unless overridden.
      def split
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/validation'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module Base
    # Mix-in shared by every transformer in Rumale.
    # Including classes must provide concrete #fit and #fit_transform implementations.
    module Transformer
      include Validation

      # Abstract: fit the model. Raises NotImplementedError unless overridden.
      def fit
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Abstract: fit the model and transform the given data.
      # Raises NotImplementedError unless overridden.
      def fit_transform
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
|
|
@@ -1,123 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/base_estimator'
|
|
4
|
-
require 'rumale/base/cluster_analyzer'
|
|
5
|
-
require 'rumale/pairwise_metric'
|
|
6
|
-
|
|
7
|
-
module Rumale
  module Clustering
    # DBSCAN is a class that implements DBSCAN cluster analysis.
    #
    # @example
    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
    class DBSCAN
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the core sample indices, i.e. the samples having at least min_samples neighbors, in ascending order.
      # @return [Numo::Int32] (shape: [n_core_samples])
      attr_reader :core_sample_ids

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with DBSCAN method.
      #
      # @param eps [Float] The radius of neighborhood.
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      # @param metric [String] The metric to calculate the distances.
      #   If metric is 'euclidean', Euclidean distance is calculated for distance between points.
      #   If metric is 'precomputed', the fit and fit_transform methods expect to be given a distance matrix.
      #   Any other value falls back to 'euclidean'.
      def initialize(eps: 0.5, min_samples: 5, metric: 'euclidean')
        check_params_numeric(eps: eps, min_samples: min_samples)
        check_params_string(metric: metric)
        @params = {}
        @params[:eps] = eps
        @params[:min_samples] = min_samples
        @params[:metric] = metric == 'precomputed' ? 'precomputed' : 'euclidean'
        @core_sample_ids = nil
        @labels = nil
      end

      # Analysis clusters with given training data.
      #
      # @overload fit(x) -> DBSCAN
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      # @return [DBSCAN] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = check_convert_sample_array(x)
        check_precomputed_shape(x)

        partial_fit(x)
        self
      end

      # Analysis clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
      #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = check_convert_sample_array(x)
        check_precomputed_shape(x)

        partial_fit(x)
        labels
      end

      private

      # Raise ArgumentError when a precomputed distance matrix is not square.
      def check_precomputed_shape(x)
        return unless @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]

        raise ArgumentError, 'Expect the input distance matrix to be square.'
      end

      # Run DBSCAN and store the results in @labels and @core_sample_ids.
      def partial_fit(x)
        cluster_id = 0
        metric_mat = calc_pairwise_metrics(x)
        n_samples = metric_mat.shape[0]
        @core_sample_ids = []
        # Label encoding: -2 = not visited yet, -1 = noise, >= 0 = cluster index.
        @labels = Numo::Int32.zeros(n_samples) - 2
        n_samples.times do |query_id|
          next if @labels[query_id] >= -1

          cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
        end
        @core_sample_ids = Numo::Int32[*@core_sample_ids.uniq.sort]
        nil
      end

      # Return x itself for 'precomputed', otherwise the pairwise Euclidean distance matrix.
      def calc_pairwise_metrics(x)
        @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
      end

      # Try to grow a new cluster from query_id.
      # Return false (marking query_id as noise) if it is not a core point, true otherwise.
      def expand_cluster(metric_mat, query_id, cluster_id)
        target_ids = region_query(metric_mat[query_id, true])
        if target_ids.size < @params[:min_samples]
          @labels[query_id] = -1
          return false
        end

        @labels[target_ids] = cluster_id
        # Only points with at least min_samples neighbors are core samples. The seed is one;
        # its neighbors are recorded below only if they pass the same test.
        @core_sample_ids.push(query_id)
        target_ids.delete(query_id)
        while (m = target_ids.shift)
          neighbor_ids = region_query(metric_mat[m, true])
          next if neighbor_ids.size < @params[:min_samples]

          @core_sample_ids.push(m)
          neighbor_ids.each do |n|
            # Enqueue only unvisited points; noise points (-1) become border points without expansion.
            target_ids.push(n) if @labels[n] < -1
            @labels[n] = cluster_id if @labels[n] <= -1
          end
        end
        true
      end

      # Indices whose distance is strictly less than eps.
      # NOTE(review): Ester et al. define the Eps-neighborhood with <=; strict < is kept for compatibility.
      def region_query(metric_arr)
        metric_arr.lt(@params[:eps]).where.to_a
      end
    end
  end
end
|
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/base_estimator'
|
|
4
|
-
require 'rumale/base/cluster_analyzer'
|
|
5
|
-
require 'rumale/preprocessing/label_binarizer'
|
|
6
|
-
|
|
7
|
-
module Rumale
  module Clustering
    # GaussianMixture is a class that implements cluster analysis with gaussian mixture model.
    #
    # @example
    #   analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    #   # If Numo::Linalg is installed, you can specify 'full' for the type of covariance option.
    #   require 'numo/linalg/autoloader'
    #   analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50, covariance_type: 'full')
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    class GaussianMixture
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the number of iterations to convergence.
      # @return [Integer]
      attr_reader :n_iter

      # Return the weight of each cluster.
      # @return [Numo::DFloat] (shape: [n_clusters])
      attr_reader :weights

      # Return the mean of each cluster.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :means

      # Return the covariance of each cluster.
      # Only the diagonal elements are stored if covariance_type is 'diag'; full matrices if it is 'full'.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features] if 'diag', [n_clusters, n_features, n_features] if 'full')
      attr_reader :covariances

      # Create a new cluster analyzer with gaussian mixture model.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      # @param covariance_type [String] The type of covariance parameter to be used ('diag' or 'full').
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param reg_covar [Float] The non-negative regularization to the diagonal of covariance.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
        check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, tol: tol, reg_covar: reg_covar)
        check_params_string(init: init)
        check_params_numeric_or_nil(random_seed: random_seed)
        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
        @params = {}
        @params[:n_clusters] = n_clusters
        @params[:init] = init == 'random' ? 'random' : 'k-means++'
        @params[:covariance_type] = covariance_type == 'full' ? 'full' : 'diag'
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        @params[:reg_covar] = reg_covar
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @n_iter = nil
        @weights = nil
        @means = nil
        @covariances = nil
      end

      # Analysis clusters with given training data.
      #
      # @overload fit(x) -> GaussianMixture
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [GaussianMixture] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = check_convert_sample_array(x)
        check_enable_linalg('fit')

        n_samples = x.shape[0]
        memberships = init_memberships(x)
        # EM iterations: M-step (weights, means, covariances) then E-step (memberships).
        @params[:max_iter].times do |t|
          @n_iter = t
          @weights = calc_weights(n_samples, memberships)
          @means = calc_means(x, memberships)
          @covariances = calc_covariances(x, @means, memberships, @params[:reg_covar], @params[:covariance_type])
          new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
          error = (memberships - new_memberships).abs.max
          break if error <= @params[:tol]

          memberships = new_memberships.dup
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = check_convert_sample_array(x)
        check_enable_linalg('predict')

        memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
        assign_cluster(memberships)
      end

      # Analysis clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = check_convert_sample_array(x)
        check_enable_linalg('fit_predict')

        fit(x).predict(x)
      end

      private

      # Return the column index of the largest membership in each row.
      # The offsets 0, n_clusters, 2*n_clusters, ... subtracted here convert max_index's
      # indices into per-row column indices (presumably max_index indexes the flattened array).
      def assign_cluster(memberships)
        n_clusters = memberships.shape[1]
        memberships.max_index(axis: 1) - Numo::Int32[*0.step(memberships.size - 1, n_clusters)]
      end

      # Initial hard memberships (one-hot) from a k-means initialization without iterations.
      def init_memberships(x)
        kmeans = Rumale::Clustering::KMeans.new(
          n_clusters: @params[:n_clusters], init: @params[:init], max_iter: 0, random_seed: @params[:random_seed]
        )
        cluster_ids = kmeans.fit_predict(x)
        encoder = Rumale::Preprocessing::LabelBinarizer.new
        Numo::DFloat.cast(encoder.fit_transform(cluster_ids))
      end

      # E-step: posterior membership of each sample in each cluster (rows sum to 1).
      # Computed in the log domain so that densities which would underflow to zero
      # (high-dimensional data, distant samples) do not turn into 0/0 = NaN.
      def calc_memberships(x, weights, means, covars, covar_type)
        n_samples = x.shape[0]
        n_clusters = means.shape[0]
        log_memberships = Numo::DFloat.zeros(n_samples, n_clusters)
        n_clusters.times do |n|
          centered = x - means[n, true]
          covar = covar_type == 'full' ? covars[n, true, true] : covars[n, true]
          log_memberships[true, n] = calc_log_unnormalized_membership(centered, weights[n], covar, covar_type)
        end
        # log-sum-exp trick: shift each row by its maximum before exponentiating.
        log_memberships -= log_memberships.max(1).expand_dims(1)
        memberships = Numo::NMath.exp(log_memberships)
        memberships / memberships.sum(1).expand_dims(1)
      end

      def calc_weights(n_samples, memberships)
        memberships.sum(0) / n_samples
      end

      def calc_means(x, memberships)
        memberships.transpose.dot(x) / memberships.sum(0).expand_dims(1)
      end

      def calc_covariances(x, means, memberships, reg_cover, covar_type)
        if covar_type == 'full'
          calc_full_covariances(x, means, reg_cover, memberships)
        else
          calc_diag_covariances(x, means, reg_cover, memberships)
        end
      end

      def calc_diag_covariances(x, means, reg_cover, memberships)
        n_clusters = means.shape[0]
        diag_cov = Array.new(n_clusters) do |n|
          centered = x - means[n, true]
          memberships[true, n].dot(centered**2) / memberships[true, n].sum
        end
        Numo::DFloat.asarray(diag_cov) + reg_cover
      end

      def calc_full_covariances(x, means, reg_cover, memberships)
        n_features = x.shape[1]
        n_clusters = means.shape[0]
        cov_mats = Numo::DFloat.zeros(n_clusters, n_features, n_features)
        reg_mat = Numo::DFloat.eye(n_features) * reg_cover
        n_clusters.times do |n|
          centered = x - means[n, true]
          members = memberships[true, n]
          cov_mats[n, true, true] = reg_mat + (centered.transpose * members).dot(centered) / members.sum
        end
        cov_mats
      end

      # log(weight * N(x | mean, covar)) up to the (2*pi)^(-d/2) constant, which cancels on normalization.
      def calc_log_unnormalized_membership(centered, weight, covar, covar_type)
        inv_covar = calc_inv_covariance(covar, covar_type)
        distances = if covar_type == 'full'
                      (centered.dot(inv_covar) * centered).sum(1)
                    else
                      (centered * inv_covar * centered).sum(1)
                    end
        Math.log(weight) - 0.5 * calc_log_det_covariance(covar, covar_type) - 0.5 * distances
      end

      def calc_inv_covariance(covar, covar_type)
        if covar_type == 'full'
          Numo::Linalg.inv(covar)
        else
          1.0 / covar
        end
      end

      # Natural log of the covariance determinant.
      # For 'diag', a sum of logs avoids the underflow/overflow of taking the raw product.
      def calc_log_det_covariance(covar, covar_type)
        if covar_type == 'full'
          Math.log(Numo::Linalg.det(covar))
        else
          Numo::NMath.log(covar).sum
        end
      end

      def check_enable_linalg(method_name)
        return unless @params[:covariance_type] == 'full' && !enable_linalg?

        raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
      end
    end
  end
end
|