rumale 0.23.3 → 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE.txt +5 -1
  3. data/README.md +3 -288
  4. data/lib/rumale/version.rb +1 -1
  5. data/lib/rumale.rb +20 -131
  6. metadata +252 -150
  7. data/CHANGELOG.md +0 -643
  8. data/CODE_OF_CONDUCT.md +0 -74
  9. data/ext/rumale/extconf.rb +0 -37
  10. data/ext/rumale/rumaleext.c +0 -545
  11. data/ext/rumale/rumaleext.h +0 -12
  12. data/lib/rumale/base/base_estimator.rb +0 -49
  13. data/lib/rumale/base/classifier.rb +0 -36
  14. data/lib/rumale/base/cluster_analyzer.rb +0 -31
  15. data/lib/rumale/base/evaluator.rb +0 -17
  16. data/lib/rumale/base/regressor.rb +0 -36
  17. data/lib/rumale/base/splitter.rb +0 -21
  18. data/lib/rumale/base/transformer.rb +0 -22
  19. data/lib/rumale/clustering/dbscan.rb +0 -123
  20. data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
  21. data/lib/rumale/clustering/hdbscan.rb +0 -291
  22. data/lib/rumale/clustering/k_means.rb +0 -122
  23. data/lib/rumale/clustering/k_medoids.rb +0 -141
  24. data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
  25. data/lib/rumale/clustering/power_iteration.rb +0 -127
  26. data/lib/rumale/clustering/single_linkage.rb +0 -203
  27. data/lib/rumale/clustering/snn.rb +0 -76
  28. data/lib/rumale/clustering/spectral_clustering.rb +0 -115
  29. data/lib/rumale/dataset.rb +0 -246
  30. data/lib/rumale/decomposition/factor_analysis.rb +0 -150
  31. data/lib/rumale/decomposition/fast_ica.rb +0 -188
  32. data/lib/rumale/decomposition/nmf.rb +0 -124
  33. data/lib/rumale/decomposition/pca.rb +0 -159
  34. data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
  35. data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
  36. data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
  37. data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
  38. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
  39. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
  40. data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
  41. data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
  42. data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
  43. data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
  44. data/lib/rumale/ensemble/voting_classifier.rb +0 -126
  45. data/lib/rumale/ensemble/voting_regressor.rb +0 -82
  46. data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
  47. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
  48. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
  49. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
  50. data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
  51. data/lib/rumale/evaluation_measure/f_score.rb +0 -50
  52. data/lib/rumale/evaluation_measure/function.rb +0 -147
  53. data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
  54. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
  55. data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
  56. data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
  57. data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
  58. data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
  59. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
  60. data/lib/rumale/evaluation_measure/precision.rb +0 -50
  61. data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
  62. data/lib/rumale/evaluation_measure/purity.rb +0 -40
  63. data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
  64. data/lib/rumale/evaluation_measure/recall.rb +0 -50
  65. data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
  66. data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
  67. data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
  68. data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
  69. data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
  70. data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
  71. data/lib/rumale/kernel_approximation/rbf.rb +0 -102
  72. data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
  73. data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
  74. data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
  75. data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
  76. data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
  77. data/lib/rumale/linear_model/base_sgd.rb +0 -285
  78. data/lib/rumale/linear_model/elastic_net.rb +0 -119
  79. data/lib/rumale/linear_model/lasso.rb +0 -115
  80. data/lib/rumale/linear_model/linear_regression.rb +0 -201
  81. data/lib/rumale/linear_model/logistic_regression.rb +0 -275
  82. data/lib/rumale/linear_model/nnls.rb +0 -137
  83. data/lib/rumale/linear_model/ridge.rb +0 -209
  84. data/lib/rumale/linear_model/svc.rb +0 -213
  85. data/lib/rumale/linear_model/svr.rb +0 -132
  86. data/lib/rumale/manifold/mds.rb +0 -155
  87. data/lib/rumale/manifold/tsne.rb +0 -222
  88. data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
  89. data/lib/rumale/metric_learning/mlkr.rb +0 -161
  90. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
  91. data/lib/rumale/model_selection/cross_validation.rb +0 -125
  92. data/lib/rumale/model_selection/function.rb +0 -42
  93. data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
  94. data/lib/rumale/model_selection/group_k_fold.rb +0 -93
  95. data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
  96. data/lib/rumale/model_selection/k_fold.rb +0 -81
  97. data/lib/rumale/model_selection/shuffle_split.rb +0 -90
  98. data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
  99. data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
  100. data/lib/rumale/model_selection/time_series_split.rb +0 -91
  101. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
  102. data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
  103. data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
  104. data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
  105. data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
  106. data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
  107. data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
  108. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
  109. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
  110. data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
  111. data/lib/rumale/neural_network/adam.rb +0 -56
  112. data/lib/rumale/neural_network/base_mlp.rb +0 -248
  113. data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
  114. data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
  115. data/lib/rumale/pairwise_metric.rb +0 -152
  116. data/lib/rumale/pipeline/feature_union.rb +0 -69
  117. data/lib/rumale/pipeline/pipeline.rb +0 -175
  118. data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
  119. data/lib/rumale/preprocessing/binarizer.rb +0 -60
  120. data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
  121. data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
  122. data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
  123. data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
  124. data/lib/rumale/preprocessing/label_encoder.rb +0 -79
  125. data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
  126. data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
  127. data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
  128. data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
  129. data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
  130. data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
  131. data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
  132. data/lib/rumale/probabilistic_output.rb +0 -114
  133. data/lib/rumale/tree/base_decision_tree.rb +0 -150
  134. data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
  135. data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
  136. data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
  137. data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
  138. data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
  139. data/lib/rumale/tree/node.rb +0 -39
  140. data/lib/rumale/utils.rb +0 -42
  141. data/lib/rumale/validation.rb +0 -128
  142. data/lib/rumale/values.rb +0 -13
@@ -1,49 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Rumale
4
- # This module consists of basic mix-in classes.
5
- module Base
6
- # Base module for all estimators in Rumale.
7
- module BaseEstimator
8
- # Return parameters about an estimator.
9
- # @return [Hash]
10
- attr_reader :params
11
-
12
- private
13
-
14
- def enable_linalg?(warning: true)
15
- if defined?(Numo::Linalg).nil?
16
- warn('If you want to use features that depend on Numo::Linalg, you should install and load Numo::Linalg in advance.') if warning
17
- return false
18
- end
19
- if Numo::Linalg::VERSION < '0.1.4'
20
- if warning
21
- warn('The loaded Numo::Linalg does not implement the methods required by Rumale. Please load Numo::Linalg version 0.1.4 or later.')
22
- end
23
- return false
24
- end
25
- true
26
- end
27
-
28
- def enable_parallel?
29
- return false if @params[:n_jobs].nil?
30
-
31
- if defined?(Parallel).nil?
32
- warn('If you want to use parallel option, you should install and load Parallel in advance.')
33
- return false
34
- end
35
- true
36
- end
37
-
38
- def n_processes
39
- return 1 unless enable_parallel?
40
-
41
- @params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
42
- end
43
-
44
- def parallel_map(n_outputs, &block)
45
- Parallel.map(Array.new(n_outputs) { |v| v }, in_processes: n_processes, &block)
46
- end
47
- end
48
- end
49
- end
@@ -1,36 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/accuracy'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all classifiers in Rumale.
9
- module Classifier
10
- include Validation
11
-
12
- # An abstract method for fitting a model.
13
- def fit
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # An abstract method for predicting labels.
18
- def predict
19
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
20
- end
21
-
22
- # Calculate the mean accuracy of the given testing data.
23
- #
24
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
25
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
26
- # @return [Float] Mean accuracy
27
- def score(x, y)
28
- x = check_convert_sample_array(x)
29
- y = check_convert_label_array(y)
30
- check_sample_label_size(x, y)
31
- evaluator = Rumale::EvaluationMeasure::Accuracy.new
32
- evaluator.score(y, predict(x))
33
- end
34
- end
35
- end
36
- end
@@ -1,31 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/purity'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all clustering algorithms in Rumale.
9
- module ClusterAnalyzer
10
- include Validation
11
-
12
- # An abstract method for analyzing clusters and predicting cluster indices.
13
- def fit_predict
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # Calculate purity of clustering result.
18
- #
19
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
20
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
21
- # @return [Float] Purity
22
- def score(x, y)
23
- x = check_convert_sample_array(x)
24
- y = check_convert_label_array(y)
25
- check_sample_label_size(x, y)
26
- evaluator = Rumale::EvaluationMeasure::Purity.new
27
- evaluator.score(y, fit_predict(x))
28
- end
29
- end
30
- end
31
- end
@@ -1,17 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all evaluation measures in Rumale.
8
- module Evaluator
9
- include Validation
10
-
11
- # An abstract method for evaluation of model.
12
- def score
13
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
14
- end
15
- end
16
- end
17
- end
@@ -1,36 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/r2_score'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all regressors in Rumale.
9
- module Regressor
10
- include Validation
11
-
12
- # An abstract method for fitting a model.
13
- def fit
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # An abstract method for predicting labels.
18
- def predict
19
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
20
- end
21
-
22
- # Calculate the coefficient of determination for the given testing data.
23
- #
24
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
25
- # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
26
- # @return [Float] Coefficient of determination
27
- def score(x, y)
28
- x = check_convert_sample_array(x)
29
- y = check_convert_tvalue_array(y)
30
- check_sample_tvalue_size(x, y)
31
- evaluator = Rumale::EvaluationMeasure::R2Score.new
32
- evaluator.score(y, predict(x))
33
- end
34
- end
35
- end
36
- end
@@ -1,21 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all validation methods in Rumale.
8
- module Splitter
9
- include Validation
10
-
11
- # Return the number of splits.
12
- # @return [Integer]
13
- attr_reader :n_splits
14
-
15
- # An abstract method for splitting dataset.
16
- def split
17
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
18
- end
19
- end
20
- end
21
- end
@@ -1,22 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all transformers in Rumale.
8
- module Transformer
9
- include Validation
10
-
11
- # An abstract method for fitting a model.
12
- def fit
13
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
14
- end
15
-
16
- # An abstract method for fitting a model and transforming given data.
17
- def fit_transform
18
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
19
- end
20
- end
21
- end
22
- end
@@ -1,123 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/base_estimator'
4
- require 'rumale/base/cluster_analyzer'
5
- require 'rumale/pairwise_metric'
6
-
7
- module Rumale
8
- module Clustering
9
- # DBSCAN is a class that implements DBSCAN cluster analysis.
10
- #
11
- # @example
12
- # analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
13
- # cluster_labels = analyzer.fit_predict(samples)
14
- #
15
- # *Reference*
16
- # - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
17
- class DBSCAN
18
- include Base::BaseEstimator
19
- include Base::ClusterAnalyzer
20
-
21
- # Return the core sample indices.
22
- # @return [Numo::Int32] (shape: [n_core_samples])
23
- attr_reader :core_sample_ids
24
-
25
- # Return the cluster labels. The negative cluster label indicates that the point is noise.
26
- # @return [Numo::Int32] (shape: [n_samples])
27
- attr_reader :labels
28
-
29
- # Create a new cluster analyzer with DBSCAN method.
30
- #
31
- # @param eps [Float] The radius of neighborhood.
32
- # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
33
- # @param metric [String] The metric to calculate the distances.
34
- # If metric is 'euclidean', Euclidean distance is calculated for distance between points.
35
- # If metric is 'precomputed', the fit and fit_predict methods expect to be given a distance matrix.
36
- def initialize(eps: 0.5, min_samples: 5, metric: 'euclidean')
37
- check_params_numeric(eps: eps, min_samples: min_samples)
38
- check_params_string(metric: metric)
39
- @params = {}
40
- @params[:eps] = eps
41
- @params[:min_samples] = min_samples
42
- @params[:metric] = metric == 'precomputed' ? 'precomputed' : 'euclidean'
43
- @core_sample_ids = nil
44
- @labels = nil
45
- end
46
-
47
- # Analyze clusters with given training data.
48
- #
49
- # @overload fit(x) -> DBSCAN
50
- #
51
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
52
- # If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
53
- # @return [DBSCAN] The learned cluster analyzer itself.
54
- def fit(x, _y = nil)
55
- x = check_convert_sample_array(x)
56
- raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
57
-
58
- partial_fit(x)
59
- self
60
- end
61
-
62
- # Analyze clusters and assign samples to clusters.
63
- #
64
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
65
- # If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
66
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
67
- def fit_predict(x)
68
- x = check_convert_sample_array(x)
69
- raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
70
-
71
- partial_fit(x)
72
- labels
73
- end
74
-
75
- private
76
-
77
- def partial_fit(x)
78
- cluster_id = 0
79
- metric_mat = calc_pairwise_metrics(x)
80
- n_samples = metric_mat.shape[0]
81
- @core_sample_ids = []
82
- @labels = Numo::Int32.zeros(n_samples) - 2
83
- n_samples.times do |query_id|
84
- next if @labels[query_id] >= -1
85
-
86
- cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
87
- end
88
- @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
89
- nil
90
- end
91
-
92
- def calc_pairwise_metrics(x)
93
- @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
94
- end
95
-
96
- def expand_cluster(metric_mat, query_id, cluster_id)
97
- target_ids = region_query(metric_mat[query_id, true])
98
- if target_ids.size < @params[:min_samples]
99
- @labels[query_id] = -1
100
- false
101
- else
102
- @labels[target_ids] = cluster_id
103
- @core_sample_ids.push(target_ids.dup)
104
- target_ids.delete(query_id)
105
- while (m = target_ids.shift)
106
- neighbor_ids = region_query(metric_mat[m, true])
107
- next if neighbor_ids.size < @params[:min_samples]
108
-
109
- neighbor_ids.each do |n|
110
- target_ids.push(n) if @labels[n] < -1
111
- @labels[n] = cluster_id if @labels[n] <= -1
112
- end
113
- end
114
- true
115
- end
116
- end
117
-
118
- def region_query(metric_arr)
119
- metric_arr.lt(@params[:eps]).where.to_a
120
- end
121
- end
122
- end
123
- end
@@ -1,218 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/base_estimator'
4
- require 'rumale/base/cluster_analyzer'
5
- require 'rumale/preprocessing/label_binarizer'
6
-
7
- module Rumale
8
- module Clustering
9
- # GaussianMixture is a class that implements cluster analysis with gaussian mixture model.
10
- #
11
- # @example
12
- # analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50)
13
- # cluster_labels = analyzer.fit_predict(samples)
14
- #
15
- # # If Numo::Linalg is installed, you can specify 'full' for the type of covariance option.
16
- # require 'numo/linalg/autoloader'
17
- # analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50, covariance_type: 'full')
18
- # cluster_labels = analyzer.fit_predict(samples)
19
- #
20
- class GaussianMixture
21
- include Base::BaseEstimator
22
- include Base::ClusterAnalyzer
23
-
24
- # Return the number of iterations to convergence.
25
- # @return [Integer]
26
- attr_reader :n_iter
27
-
28
- # Return the weight of each cluster.
29
- # @return [Numo::DFloat] (shape: [n_clusters])
30
- attr_reader :weights
31
-
32
- # Return the mean of each cluster.
33
- # @return [Numo::DFloat] (shape: [n_clusters, n_features])
34
- attr_reader :means
35
-
36
- # Return the diagonal elements of covariance matrix of each cluster.
37
- # @return [Numo::DFloat] (shape: [n_clusters, n_features] if 'diag', [n_clusters, n_features, n_features] if 'full')
38
- attr_reader :covariances
39
-
40
- # Create a new cluster analyzer with gaussian mixture model.
41
- #
42
- # @param n_clusters [Integer] The number of clusters.
43
- # @param init [String] The initialization method for centroids ('random' or 'k-means++').
44
- # @param covariance_type [String] The type of covariance parameter to be used ('diag' or 'full').
45
- # @param max_iter [Integer] The maximum number of iterations.
46
- # @param tol [Float] The tolerance of termination criterion.
47
- # @param reg_covar [Float] The non-negative regularization to the diagonal of covariance.
48
- # @param random_seed [Integer] The seed value used to initialize the random generator.
49
- def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
50
- check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, tol: tol)
51
- check_params_string(init: init)
52
- check_params_numeric_or_nil(random_seed: random_seed)
53
- check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
54
- @params = {}
55
- @params[:n_clusters] = n_clusters
56
- @params[:init] = init == 'random' ? 'random' : 'k-means++'
57
- @params[:covariance_type] = covariance_type == 'full' ? 'full' : 'diag'
58
- @params[:max_iter] = max_iter
59
- @params[:tol] = tol
60
- @params[:reg_covar] = reg_covar
61
- @params[:random_seed] = random_seed
62
- @params[:random_seed] ||= srand
63
- @n_iter = nil
64
- @weights = nil
65
- @means = nil
66
- @covariances = nil
67
- end
68
-
69
- # Analyze clusters with given training data.
70
- #
71
- # @overload fit(x) -> GaussianMixture
72
- #
73
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
74
- # @return [GaussianMixture] The learned cluster analyzer itself.
75
- def fit(x, _y = nil)
76
- x = check_convert_sample_array(x)
77
- check_enable_linalg('fit')
78
-
79
- n_samples = x.shape[0]
80
- memberships = init_memberships(x)
81
- @params[:max_iter].times do |t|
82
- @n_iter = t
83
- @weights = calc_weights(n_samples, memberships)
84
- @means = calc_means(x, memberships)
85
- @covariances = calc_covariances(x, @means, memberships, @params[:reg_covar], @params[:covariance_type])
86
- new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
87
- error = (memberships - new_memberships).abs.max
88
- break if error <= @params[:tol]
89
-
90
- memberships = new_memberships.dup
91
- end
92
- self
93
- end
94
-
95
- # Predict cluster labels for samples.
96
- #
97
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
98
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
99
- def predict(x)
100
- x = check_convert_sample_array(x)
101
- check_enable_linalg('predict')
102
-
103
- memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
104
- assign_cluster(memberships)
105
- end
106
-
107
- # Analyze clusters and assign samples to clusters.
108
- #
109
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
110
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
111
- def fit_predict(x)
112
- x = check_convert_sample_array(x)
113
- check_enable_linalg('fit_predict')
114
-
115
- fit(x).predict(x)
116
- end
117
-
118
- private
119
-
120
- def assign_cluster(memberships)
121
- n_clusters = memberships.shape[1]
122
- memberships.max_index(axis: 1) - Numo::Int32[*0.step(memberships.size - 1, n_clusters)]
123
- end
124
-
125
- def init_memberships(x)
126
- kmeans = Rumale::Clustering::KMeans.new(
127
- n_clusters: @params[:n_clusters], init: @params[:init], max_iter: 0, random_seed: @params[:random_seed]
128
- )
129
- cluster_ids = kmeans.fit_predict(x)
130
- encoder = Rumale::Preprocessing::LabelBinarizer.new
131
- Numo::DFloat.cast(encoder.fit_transform(cluster_ids))
132
- end
133
-
134
- def calc_memberships(x, weights, means, covars, covar_type)
135
- n_samples = x.shape[0]
136
- n_clusters = means.shape[0]
137
- memberships = Numo::DFloat.zeros(n_samples, n_clusters)
138
- n_clusters.times do |n|
139
- centered = x - means[n, true]
140
- covar = covar_type == 'full' ? covars[n, true, true] : covars[n, true]
141
- memberships[true, n] = calc_unnormalized_membership(centered, weights[n], covar, covar_type)
142
- end
143
- memberships / memberships.sum(1).expand_dims(1)
144
- end
145
-
146
- def calc_weights(n_samples, memberships)
147
- memberships.sum(0) / n_samples
148
- end
149
-
150
- def calc_means(x, memberships)
151
- memberships.transpose.dot(x) / memberships.sum(0).expand_dims(1)
152
- end
153
-
154
- def calc_covariances(x, means, memberships, reg_cover, covar_type)
155
- if covar_type == 'full'
156
- calc_full_covariances(x, means, reg_cover, memberships)
157
- else
158
- calc_diag_covariances(x, means, reg_cover, memberships)
159
- end
160
- end
161
-
162
- def calc_diag_covariances(x, means, reg_cover, memberships)
163
- n_clusters = means.shape[0]
164
- diag_cov = Array.new(n_clusters) do |n|
165
- centered = x - means[n, true]
166
- memberships[true, n].dot(centered**2) / memberships[true, n].sum
167
- end
168
- Numo::DFloat.asarray(diag_cov) + reg_cover
169
- end
170
-
171
- def calc_full_covariances(x, means, reg_cover, memberships)
172
- n_features = x.shape[1]
173
- n_clusters = means.shape[0]
174
- cov_mats = Numo::DFloat.zeros(n_clusters, n_features, n_features)
175
- reg_mat = Numo::DFloat.eye(n_features) * reg_cover
176
- n_clusters.times do |n|
177
- centered = x - means[n, true]
178
- members = memberships[true, n]
179
- cov_mats[n, true, true] = reg_mat + (centered.transpose * members).dot(centered) / members.sum
180
- end
181
- cov_mats
182
- end
183
-
184
- def calc_unnormalized_membership(centered, weight, covar, covar_type)
185
- inv_covar = calc_inv_covariance(covar, covar_type)
186
- inv_sqrt_det_covar = calc_inv_sqrt_det_covariance(covar, covar_type)
187
- distances = if covar_type == 'full'
188
- (centered.dot(inv_covar) * centered).sum(1)
189
- else
190
- (centered * inv_covar * centered).sum(1)
191
- end
192
- weight * inv_sqrt_det_covar * Numo::NMath.exp(-0.5 * distances)
193
- end
194
-
195
- def calc_inv_covariance(covar, covar_type)
196
- if covar_type == 'full'
197
- Numo::Linalg.inv(covar)
198
- else
199
- 1.0 / covar
200
- end
201
- end
202
-
203
- def calc_inv_sqrt_det_covariance(covar, covar_type)
204
- if covar_type == 'full'
205
- 1.0 / Math.sqrt(Numo::Linalg.det(covar))
206
- else
207
- 1.0 / Math.sqrt(covar.prod)
208
- end
209
- end
210
-
211
- def check_enable_linalg(method_name)
212
- return unless @params[:covariance_type] == 'full' && !enable_linalg?
213
-
214
- raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
215
- end
216
- end
217
- end
218
- end