rumale 0.23.3 → 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE.txt +5 -1
  3. data/README.md +3 -288
  4. data/lib/rumale/version.rb +1 -1
  5. data/lib/rumale.rb +20 -131
  6. metadata +252 -150
  7. data/CHANGELOG.md +0 -643
  8. data/CODE_OF_CONDUCT.md +0 -74
  9. data/ext/rumale/extconf.rb +0 -37
  10. data/ext/rumale/rumaleext.c +0 -545
  11. data/ext/rumale/rumaleext.h +0 -12
  12. data/lib/rumale/base/base_estimator.rb +0 -49
  13. data/lib/rumale/base/classifier.rb +0 -36
  14. data/lib/rumale/base/cluster_analyzer.rb +0 -31
  15. data/lib/rumale/base/evaluator.rb +0 -17
  16. data/lib/rumale/base/regressor.rb +0 -36
  17. data/lib/rumale/base/splitter.rb +0 -21
  18. data/lib/rumale/base/transformer.rb +0 -22
  19. data/lib/rumale/clustering/dbscan.rb +0 -123
  20. data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
  21. data/lib/rumale/clustering/hdbscan.rb +0 -291
  22. data/lib/rumale/clustering/k_means.rb +0 -122
  23. data/lib/rumale/clustering/k_medoids.rb +0 -141
  24. data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
  25. data/lib/rumale/clustering/power_iteration.rb +0 -127
  26. data/lib/rumale/clustering/single_linkage.rb +0 -203
  27. data/lib/rumale/clustering/snn.rb +0 -76
  28. data/lib/rumale/clustering/spectral_clustering.rb +0 -115
  29. data/lib/rumale/dataset.rb +0 -246
  30. data/lib/rumale/decomposition/factor_analysis.rb +0 -150
  31. data/lib/rumale/decomposition/fast_ica.rb +0 -188
  32. data/lib/rumale/decomposition/nmf.rb +0 -124
  33. data/lib/rumale/decomposition/pca.rb +0 -159
  34. data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
  35. data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
  36. data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
  37. data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
  38. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
  39. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
  40. data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
  41. data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
  42. data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
  43. data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
  44. data/lib/rumale/ensemble/voting_classifier.rb +0 -126
  45. data/lib/rumale/ensemble/voting_regressor.rb +0 -82
  46. data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
  47. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
  48. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
  49. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
  50. data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
  51. data/lib/rumale/evaluation_measure/f_score.rb +0 -50
  52. data/lib/rumale/evaluation_measure/function.rb +0 -147
  53. data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
  54. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
  55. data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
  56. data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
  57. data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
  58. data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
  59. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
  60. data/lib/rumale/evaluation_measure/precision.rb +0 -50
  61. data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
  62. data/lib/rumale/evaluation_measure/purity.rb +0 -40
  63. data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
  64. data/lib/rumale/evaluation_measure/recall.rb +0 -50
  65. data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
  66. data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
  67. data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
  68. data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
  69. data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
  70. data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
  71. data/lib/rumale/kernel_approximation/rbf.rb +0 -102
  72. data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
  73. data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
  74. data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
  75. data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
  76. data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
  77. data/lib/rumale/linear_model/base_sgd.rb +0 -285
  78. data/lib/rumale/linear_model/elastic_net.rb +0 -119
  79. data/lib/rumale/linear_model/lasso.rb +0 -115
  80. data/lib/rumale/linear_model/linear_regression.rb +0 -201
  81. data/lib/rumale/linear_model/logistic_regression.rb +0 -275
  82. data/lib/rumale/linear_model/nnls.rb +0 -137
  83. data/lib/rumale/linear_model/ridge.rb +0 -209
  84. data/lib/rumale/linear_model/svc.rb +0 -213
  85. data/lib/rumale/linear_model/svr.rb +0 -132
  86. data/lib/rumale/manifold/mds.rb +0 -155
  87. data/lib/rumale/manifold/tsne.rb +0 -222
  88. data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
  89. data/lib/rumale/metric_learning/mlkr.rb +0 -161
  90. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
  91. data/lib/rumale/model_selection/cross_validation.rb +0 -125
  92. data/lib/rumale/model_selection/function.rb +0 -42
  93. data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
  94. data/lib/rumale/model_selection/group_k_fold.rb +0 -93
  95. data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
  96. data/lib/rumale/model_selection/k_fold.rb +0 -81
  97. data/lib/rumale/model_selection/shuffle_split.rb +0 -90
  98. data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
  99. data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
  100. data/lib/rumale/model_selection/time_series_split.rb +0 -91
  101. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
  102. data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
  103. data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
  104. data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
  105. data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
  106. data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
  107. data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
  108. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
  109. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
  110. data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
  111. data/lib/rumale/neural_network/adam.rb +0 -56
  112. data/lib/rumale/neural_network/base_mlp.rb +0 -248
  113. data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
  114. data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
  115. data/lib/rumale/pairwise_metric.rb +0 -152
  116. data/lib/rumale/pipeline/feature_union.rb +0 -69
  117. data/lib/rumale/pipeline/pipeline.rb +0 -175
  118. data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
  119. data/lib/rumale/preprocessing/binarizer.rb +0 -60
  120. data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
  121. data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
  122. data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
  123. data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
  124. data/lib/rumale/preprocessing/label_encoder.rb +0 -79
  125. data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
  126. data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
  127. data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
  128. data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
  129. data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
  130. data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
  131. data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
  132. data/lib/rumale/probabilistic_output.rb +0 -114
  133. data/lib/rumale/tree/base_decision_tree.rb +0 -150
  134. data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
  135. data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
  136. data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
  137. data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
  138. data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
  139. data/lib/rumale/tree/node.rb +0 -39
  140. data/lib/rumale/utils.rb +0 -42
  141. data/lib/rumale/validation.rb +0 -128
  142. data/lib/rumale/values.rb +0 -13
@@ -1,49 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Rumale
4
- # This module consists of basic mix-in classes.
5
- module Base
6
- # Base module for all estimators in Rumale.
7
- module BaseEstimator
8
- # Return parameters about an estimator.
9
- # @return [Hash]
10
- attr_reader :params
11
-
12
- private
13
-
14
- def enable_linalg?(warning: true)
15
- if defined?(Numo::Linalg).nil?
16
- warn('If you want to use features that depend on Numo::Linalg, you should install and load Numo::Linalg in advance.') if warning
17
- return false
18
- end
19
- if Numo::Linalg::VERSION < '0.1.4'
20
- if warning
21
- warn('The loaded Numo::Linalg does not implement the methods required by Rumale. Please load Numo::Linalg version 0.1.4 or later.')
22
- end
23
- return false
24
- end
25
- true
26
- end
27
-
28
- def enable_parallel?
29
- return false if @params[:n_jobs].nil?
30
-
31
- if defined?(Parallel).nil?
32
- warn('If you want to use parallel option, you should install and load Parallel in advance.')
33
- return false
34
- end
35
- true
36
- end
37
-
38
- def n_processes
39
- return 1 unless enable_parallel?
40
-
41
- @params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
42
- end
43
-
44
- def parallel_map(n_outputs, &block)
45
- Parallel.map(Array.new(n_outputs) { |v| v }, in_processes: n_processes, &block)
46
- end
47
- end
48
- end
49
- end
@@ -1,36 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/accuracy'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all classifiers in Rumale.
9
- module Classifier
10
- include Validation
11
-
12
- # An abstract method for fitting a model.
13
- def fit
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # An abstract method for predicting labels.
18
- def predict
19
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
20
- end
21
-
22
- # Calculate the mean accuracy of the given testing data.
23
- #
24
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
25
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
26
- # @return [Float] Mean accuracy
27
- def score(x, y)
28
- x = check_convert_sample_array(x)
29
- y = check_convert_label_array(y)
30
- check_sample_label_size(x, y)
31
- evaluator = Rumale::EvaluationMeasure::Accuracy.new
32
- evaluator.score(y, predict(x))
33
- end
34
- end
35
- end
36
- end
@@ -1,31 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/purity'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all clustering algorithms in Rumale.
9
- module ClusterAnalyzer
10
- include Validation
11
-
12
- # An abstract method for analyzing clusters and predicting cluster indices.
13
- def fit_predict
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # Calculate purity of clustering result.
18
- #
19
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
20
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
21
- # @return [Float] Purity
22
- def score(x, y)
23
- x = check_convert_sample_array(x)
24
- y = check_convert_label_array(y)
25
- check_sample_label_size(x, y)
26
- evaluator = Rumale::EvaluationMeasure::Purity.new
27
- evaluator.score(y, fit_predict(x))
28
- end
29
- end
30
- end
31
- end
@@ -1,17 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all evaluation measures in Rumale.
8
- module Evaluator
9
- include Validation
10
-
11
- # An abstract method for evaluation of model.
12
- def score
13
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
14
- end
15
- end
16
- end
17
- end
@@ -1,36 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/evaluation_measure/r2_score'
5
-
6
- module Rumale
7
- module Base
8
- # Module for all regressors in Rumale.
9
- module Regressor
10
- include Validation
11
-
12
- # An abstract method for fitting a model.
13
- def fit
14
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
15
- end
16
-
17
- # An abstract method for predicting target values.
18
- def predict
19
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
20
- end
21
-
22
- # Calculate the coefficient of determination for the given testing data.
23
- #
24
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
25
- # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
26
- # @return [Float] Coefficient of determination
27
- def score(x, y)
28
- x = check_convert_sample_array(x)
29
- y = check_convert_tvalue_array(y)
30
- check_sample_tvalue_size(x, y)
31
- evaluator = Rumale::EvaluationMeasure::R2Score.new
32
- evaluator.score(y, predict(x))
33
- end
34
- end
35
- end
36
- end
@@ -1,21 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all validation methods in Rumale.
8
- module Splitter
9
- include Validation
10
-
11
- # Return the number of splits.
12
- # @return [Integer]
13
- attr_reader :n_splits
14
-
15
- # An abstract method for splitting dataset.
16
- def split
17
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
18
- end
19
- end
20
- end
21
- end
@@ -1,22 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
-
5
- module Rumale
6
- module Base
7
- # Module for all transformers in Rumale.
8
- module Transformer
9
- include Validation
10
-
11
- # An abstract method for fitting a model.
12
- def fit
13
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
14
- end
15
-
16
- # An abstract method for fitting a model and transforming given data.
17
- def fit_transform
18
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
19
- end
20
- end
21
- end
22
- end
@@ -1,123 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/base_estimator'
4
- require 'rumale/base/cluster_analyzer'
5
- require 'rumale/pairwise_metric'
6
-
7
- module Rumale
8
- module Clustering
9
- # DBSCAN is a class that implements DBSCAN cluster analysis.
10
- #
11
- # @example
12
- # analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
13
- # cluster_labels = analyzer.fit_predict(samples)
14
- #
15
- # *Reference*
16
- # - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD'96, pp. 226--231, 1996.
17
- class DBSCAN
18
- include Base::BaseEstimator
19
- include Base::ClusterAnalyzer
20
-
21
- # Return the core sample indices.
22
- # @return [Numo::Int32] (shape: [n_core_samples])
23
- attr_reader :core_sample_ids
24
-
25
- # Return the cluster labels. The negative cluster label indicates that the point is noise.
26
- # @return [Numo::Int32] (shape: [n_samples])
27
- attr_reader :labels
28
-
29
- # Create a new cluster analyzer with DBSCAN method.
30
- #
31
- # @param eps [Float] The radius of neighborhood.
32
- # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
33
- # @param metric [String] The metric to calculate the distances.
34
- # If metric is 'euclidean', Euclidean distance is calculated for distance between points.
35
- # If metric is 'precomputed', the fit and fit_predict methods expect to be given a distance matrix.
36
- def initialize(eps: 0.5, min_samples: 5, metric: 'euclidean')
37
- check_params_numeric(eps: eps, min_samples: min_samples)
38
- check_params_string(metric: metric)
39
- @params = {}
40
- @params[:eps] = eps
41
- @params[:min_samples] = min_samples
42
- @params[:metric] = metric == 'precomputed' ? 'precomputed' : 'euclidean'
43
- @core_sample_ids = nil
44
- @labels = nil
45
- end
46
-
47
- # Analyze clusters with given training data.
48
- #
49
- # @overload fit(x) -> DBSCAN
50
- #
51
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
52
- # If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
53
- # @return [DBSCAN] The learned cluster analyzer itself.
54
- def fit(x, _y = nil)
55
- x = check_convert_sample_array(x)
56
- raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
57
-
58
- partial_fit(x)
59
- self
60
- end
61
-
62
- # Analyze clusters and assign samples to clusters.
63
- #
64
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
65
- # If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
66
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
67
- def fit_predict(x)
68
- x = check_convert_sample_array(x)
69
- raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
70
-
71
- partial_fit(x)
72
- labels
73
- end
74
-
75
- private
76
-
77
- def partial_fit(x)
78
- cluster_id = 0
79
- metric_mat = calc_pairwise_metrics(x)
80
- n_samples = metric_mat.shape[0]
81
- @core_sample_ids = []
82
- @labels = Numo::Int32.zeros(n_samples) - 2
83
- n_samples.times do |query_id|
84
- next if @labels[query_id] >= -1
85
-
86
- cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
87
- end
88
- @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
89
- nil
90
- end
91
-
92
- def calc_pairwise_metrics(x)
93
- @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
94
- end
95
-
96
- def expand_cluster(metric_mat, query_id, cluster_id)
97
- target_ids = region_query(metric_mat[query_id, true])
98
- if target_ids.size < @params[:min_samples]
99
- @labels[query_id] = -1
100
- false
101
- else
102
- @labels[target_ids] = cluster_id
103
- @core_sample_ids.push(target_ids.dup)
104
- target_ids.delete(query_id)
105
- while (m = target_ids.shift)
106
- neighbor_ids = region_query(metric_mat[m, true])
107
- next if neighbor_ids.size < @params[:min_samples]
108
-
109
- neighbor_ids.each do |n|
110
- target_ids.push(n) if @labels[n] < -1
111
- @labels[n] = cluster_id if @labels[n] <= -1
112
- end
113
- end
114
- true
115
- end
116
- end
117
-
118
- def region_query(metric_arr)
119
- metric_arr.lt(@params[:eps]).where.to_a
120
- end
121
- end
122
- end
123
- end
@@ -1,218 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/base_estimator'
4
- require 'rumale/base/cluster_analyzer'
5
- require 'rumale/preprocessing/label_binarizer'
6
-
7
- module Rumale
8
- module Clustering
9
- # GaussianMixture is a class that implements cluster analysis with gaussian mixture model.
10
- #
11
- # @example
12
- # analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50)
13
- # cluster_labels = analyzer.fit_predict(samples)
14
- #
15
- # # If Numo::Linalg is installed, you can specify 'full' for the covariance_type option.
16
- # require 'numo/linalg/autoloader'
17
- # analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50, covariance_type: 'full')
18
- # cluster_labels = analyzer.fit_predict(samples)
19
- #
20
- class GaussianMixture
21
- include Base::BaseEstimator
22
- include Base::ClusterAnalyzer
23
-
24
- # Return the number of iterations to convergence.
25
- # @return [Integer]
26
- attr_reader :n_iter
27
-
28
- # Return the weight of each cluster.
29
- # @return [Numo::DFloat] (shape: [n_clusters])
30
- attr_reader :weights
31
-
32
- # Return the mean of each cluster.
33
- # @return [Numo::DFloat] (shape: [n_clusters, n_features])
34
- attr_reader :means
35
-
36
- # Return the covariance of each cluster (the diagonal elements if covariance_type is 'diag', the full matrix if 'full').
37
- # @return [Numo::DFloat] (shape: [n_clusters, n_features] if 'diag', [n_clusters, n_features, n_features] if 'full')
38
- attr_reader :covariances
39
-
40
- # Create a new cluster analyzer with gaussian mixture model.
41
- #
42
- # @param n_clusters [Integer] The number of clusters.
43
- # @param init [String] The initialization method for centroids ('random' or 'k-means++').
44
- # @param covariance_type [String] The type of covariance parameter to be used ('diag' or 'full').
45
- # @param max_iter [Integer] The maximum number of iterations.
46
- # @param tol [Float] The tolerance of termination criterion.
47
- # @param reg_covar [Float] The non-negative regularization to the diagonal of covariance.
48
- # @param random_seed [Integer] The seed value used to initialize the random generator.
49
- def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
50
- check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, tol: tol)
51
- check_params_string(init: init)
52
- check_params_numeric_or_nil(random_seed: random_seed)
53
- check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
54
- @params = {}
55
- @params[:n_clusters] = n_clusters
56
- @params[:init] = init == 'random' ? 'random' : 'k-means++'
57
- @params[:covariance_type] = covariance_type == 'full' ? 'full' : 'diag'
58
- @params[:max_iter] = max_iter
59
- @params[:tol] = tol
60
- @params[:reg_covar] = reg_covar
61
- @params[:random_seed] = random_seed
62
- @params[:random_seed] ||= srand
63
- @n_iter = nil
64
- @weights = nil
65
- @means = nil
66
- @covariances = nil
67
- end
68
-
69
- # Analyze clusters with given training data.
70
- #
71
- # @overload fit(x) -> GaussianMixture
72
- #
73
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
74
- # @return [GaussianMixture] The learned cluster analyzer itself.
75
- def fit(x, _y = nil)
76
- x = check_convert_sample_array(x)
77
- check_enable_linalg('fit')
78
-
79
- n_samples = x.shape[0]
80
- memberships = init_memberships(x)
81
- @params[:max_iter].times do |t|
82
- @n_iter = t
83
- @weights = calc_weights(n_samples, memberships)
84
- @means = calc_means(x, memberships)
85
- @covariances = calc_covariances(x, @means, memberships, @params[:reg_covar], @params[:covariance_type])
86
- new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
87
- error = (memberships - new_memberships).abs.max
88
- break if error <= @params[:tol]
89
-
90
- memberships = new_memberships.dup
91
- end
92
- self
93
- end
94
-
95
- # Predict cluster labels for samples.
96
- #
97
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
98
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
99
- def predict(x)
100
- x = check_convert_sample_array(x)
101
- check_enable_linalg('predict')
102
-
103
- memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
104
- assign_cluster(memberships)
105
- end
106
-
107
- # Analyze clusters and assign samples to clusters.
108
- #
109
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
110
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
111
- def fit_predict(x)
112
- x = check_convert_sample_array(x)
113
- check_enable_linalg('fit_predict')
114
-
115
- fit(x).predict(x)
116
- end
117
-
118
- private
119
-
120
- def assign_cluster(memberships)
121
- n_clusters = memberships.shape[1]
122
- memberships.max_index(axis: 1) - Numo::Int32[*0.step(memberships.size - 1, n_clusters)]
123
- end
124
-
125
- def init_memberships(x)
126
- kmeans = Rumale::Clustering::KMeans.new(
127
- n_clusters: @params[:n_clusters], init: @params[:init], max_iter: 0, random_seed: @params[:random_seed]
128
- )
129
- cluster_ids = kmeans.fit_predict(x)
130
- encoder = Rumale::Preprocessing::LabelBinarizer.new
131
- Numo::DFloat.cast(encoder.fit_transform(cluster_ids))
132
- end
133
-
134
- def calc_memberships(x, weights, means, covars, covar_type)
135
- n_samples = x.shape[0]
136
- n_clusters = means.shape[0]
137
- memberships = Numo::DFloat.zeros(n_samples, n_clusters)
138
- n_clusters.times do |n|
139
- centered = x - means[n, true]
140
- covar = covar_type == 'full' ? covars[n, true, true] : covars[n, true]
141
- memberships[true, n] = calc_unnormalized_membership(centered, weights[n], covar, covar_type)
142
- end
143
- memberships / memberships.sum(1).expand_dims(1)
144
- end
145
-
146
- def calc_weights(n_samples, memberships)
147
- memberships.sum(0) / n_samples
148
- end
149
-
150
- def calc_means(x, memberships)
151
- memberships.transpose.dot(x) / memberships.sum(0).expand_dims(1)
152
- end
153
-
154
- def calc_covariances(x, means, memberships, reg_cover, covar_type)
155
- if covar_type == 'full'
156
- calc_full_covariances(x, means, reg_cover, memberships)
157
- else
158
- calc_diag_covariances(x, means, reg_cover, memberships)
159
- end
160
- end
161
-
162
- def calc_diag_covariances(x, means, reg_cover, memberships)
163
- n_clusters = means.shape[0]
164
- diag_cov = Array.new(n_clusters) do |n|
165
- centered = x - means[n, true]
166
- memberships[true, n].dot(centered**2) / memberships[true, n].sum
167
- end
168
- Numo::DFloat.asarray(diag_cov) + reg_cover
169
- end
170
-
171
- def calc_full_covariances(x, means, reg_cover, memberships)
172
- n_features = x.shape[1]
173
- n_clusters = means.shape[0]
174
- cov_mats = Numo::DFloat.zeros(n_clusters, n_features, n_features)
175
- reg_mat = Numo::DFloat.eye(n_features) * reg_cover
176
- n_clusters.times do |n|
177
- centered = x - means[n, true]
178
- members = memberships[true, n]
179
- cov_mats[n, true, true] = reg_mat + (centered.transpose * members).dot(centered) / members.sum
180
- end
181
- cov_mats
182
- end
183
-
184
- def calc_unnormalized_membership(centered, weight, covar, covar_type)
185
- inv_covar = calc_inv_covariance(covar, covar_type)
186
- inv_sqrt_det_covar = calc_inv_sqrt_det_covariance(covar, covar_type)
187
- distances = if covar_type == 'full'
188
- (centered.dot(inv_covar) * centered).sum(1)
189
- else
190
- (centered * inv_covar * centered).sum(1)
191
- end
192
- weight * inv_sqrt_det_covar * Numo::NMath.exp(-0.5 * distances)
193
- end
194
-
195
- def calc_inv_covariance(covar, covar_type)
196
- if covar_type == 'full'
197
- Numo::Linalg.inv(covar)
198
- else
199
- 1.0 / covar
200
- end
201
- end
202
-
203
- def calc_inv_sqrt_det_covariance(covar, covar_type)
204
- if covar_type == 'full'
205
- 1.0 / Math.sqrt(Numo::Linalg.det(covar))
206
- else
207
- 1.0 / Math.sqrt(covar.prod)
208
- end
209
- end
210
-
211
- def check_enable_linalg(method_name)
212
- return unless @params[:covariance_type] == 'full' && !enable_linalg?
213
-
214
- raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
215
- end
216
- end
217
- end
218
- end