rumale-clustering 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 26c8d431fa54beb0ef656cb5c058176ed8b777dcd1075d4ff859c37ca458ab98
4
+ data.tar.gz: e180764368160a0273fc42e92238beaa25e93ebbbee0766dfb9f0efed2bc80fe
5
+ SHA512:
6
+ metadata.gz: e5386f87dbed2376c712b9f1e74484f757d0bd6e89b8d1c5455865405f4561ae22f4245863ecc06894202e3bea7373f97c767cd2e182172931eb58c18ee47220
7
+ data.tar.gz: 52e855b335ea4454850ffc2ab18a2c89c34849bb88f0b59af59073071d803e926c69638241f47d320f40aa06d945b105887e4e9d84ef453d404a43c1825470a5
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022 Atsushi Tatsuma
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holder nor the names of its
15
+ contributors may be used to endorse or promote products derived from
16
+ this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # Rumale::Clustering
2
+
3
+ [![Gem Version](https://badge.fury.io/rb/rumale-clustering.svg)](https://badge.fury.io/rb/rumale-clustering)
4
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-clustering/LICENSE.txt)
5
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/Clustering.html)
6
+
7
+ Rumale is a machine learning library in Ruby.
8
+ Rumale::Clustering provides cluster analysis algorithms,
9
+ such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
10
+ with the Rumale interface.
11
+
12
+ ## Installation
13
+
14
+ Add this line to your application's Gemfile:
15
+
16
+ ```ruby
17
+ gem 'rumale-clustering'
18
+ ```
19
+
20
+ And then execute:
21
+
22
+ $ bundle install
23
+
24
+ Or install it yourself as:
25
+
26
+ $ gem install rumale-clustering
27
+
28
+ ## Documentation
29
+
30
+ - [Rumale API Documentation - Clustering](https://yoshoku.github.io/rumale/doc/Rumale/Clustering.html)
31
+
32
+ ## License
33
+
34
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
@@ -0,0 +1,126 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+
8
module Rumale
  module Clustering
    # DBSCAN is a class that implements DBSCAN cluster analysis.
    #
    # @example
    #   require 'rumale/clustering/dbscan'
    #
    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
    class DBSCAN < ::Rumale::Base::Estimator
      include ::Rumale::Base::ClusterAnalyzer

      # Return the core sample indices.
      # A sample is a core sample when at least min_samples samples (itself included) lie at a
      # distance strictly less than eps from it. The indices are unique and in ascending order.
      # @return [Numo::Int32] (shape: [n_core_samples])
      attr_reader :core_sample_ids

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with DBSCAN method.
      #
      # @param eps [Float] The radius of neighborhood. Two points are neighbors when their distance is strictly less than eps.
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      # @param metric [String] The metric to calculate the distances.
      #   If metric is 'euclidean', Euclidean distance is calculated for distance between points.
      #   If metric is 'precomputed', the fit and fit_predict methods expect to be given a distance matrix.
      #   Any other value falls back to 'euclidean'.
      def initialize(eps: 0.5, min_samples: 5, metric: 'euclidean')
        super()
        @params = {
          eps: eps,
          min_samples: min_samples,
          metric: (metric == 'precomputed' ? 'precomputed' : 'euclidean')
        }
      end

      # Analyze clusters in the given training data.
      #
      # @overload fit(x) -> DBSCAN
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #     If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      # @return [DBSCAN] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = check_input_array(x)
        partial_fit(x)
        self
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
      #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = check_input_array(x)
        partial_fit(x)
        labels
      end

      private

      # Validate/convert the input and reject non-square matrices in 'precomputed' mode.
      def check_input_array(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        raise ArgumentError, 'the input distance matrix should be square' if check_invalid_array_shape(x)

        x
      end

      # True when a precomputed distance matrix is not square.
      def check_invalid_array_shape(x)
        @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
      end

      # Run DBSCAN on x and store the results in @labels and @core_sample_ids.
      # Label encoding while running: -2 = not yet visited, -1 = noise, >= 0 = cluster id.
      def partial_fit(x)
        cluster_id = 0
        metric_mat = calc_pairwise_metrics(x)
        n_samples = metric_mat.shape[0]
        @core_sample_ids = []
        @labels = Numo::Int32.zeros(n_samples) - 2
        n_samples.times do |query_id|
          # Skip samples already assigned to a cluster or already marked as noise.
          next if @labels[query_id] >= -1

          cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
        end
        # uniq guards against an asymmetric precomputed matrix revisiting a core point.
        @core_sample_ids = Numo::Int32[*@core_sample_ids.uniq.sort]
        nil
      end

      # Distance matrix: x itself when precomputed, otherwise pairwise Euclidean distances.
      def calc_pairwise_metrics(x)
        @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x)
      end

      # Try to grow a new cluster with id cluster_id from query_id.
      # Returns false (and marks query_id as noise) when query_id is not a core point,
      # otherwise labels every density-reachable sample and returns true.
      def expand_cluster(metric_mat, query_id, cluster_id)
        target_ids = region_query(metric_mat[query_id, true])
        if target_ids.size < @params[:min_samples]
          @labels[query_id] = -1
          return false
        end

        @labels[target_ids] = cluster_id
        # Only the seed itself is known to be core here; its neighbors may be border points.
        @core_sample_ids.push(query_id)
        target_ids.delete(query_id)
        while (m = target_ids.shift)
          neighbor_ids = region_query(metric_mat[m, true])
          # Border point: belongs to the cluster but does not expand it.
          next if neighbor_ids.size < @params[:min_samples]

          @core_sample_ids.push(m)
          neighbor_ids.each do |n|
            # Unvisited samples are queued for expansion; former noise becomes a border point.
            target_ids.push(n) if @labels[n] < -1
            @labels[n] = cluster_id if @labels[n] <= -1
          end
        end
        true
      end

      # Indices of the samples whose distance is strictly less than eps.
      def region_query(metric_arr)
        metric_arr.lt(@params[:eps]).where.to_a
      end
    end
  end
end
@@ -0,0 +1,215 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/utils'
6
+ require 'rumale/validation'
7
+ require 'rumale/clustering/k_means'
8
+
9
module Rumale
  module Clustering
    # GaussianMixture is a class that implements cluster analysis with gaussian mixture model.
    #
    # @example
    #   require 'rumale/clustering/gaussian_mixture'
    #
    #   analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    #   # If Numo::Linalg is installed, you can specify 'full' for the type of covariance option.
    #   require 'numo/linalg/autoloader'
    #   require 'rumale/clustering/gaussian_mixture'
    #
    #   analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50, covariance_type: 'full')
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    class GaussianMixture < ::Rumale::Base::Estimator # rubocop:disable Metrics/ClassLength
      include ::Rumale::Base::ClusterAnalyzer

      # Return the number of iterations to convergence.
      # @return [Integer]
      attr_reader :n_iter

      # Return the weight of each cluster.
      # @return [Numo::DFloat] (shape: [n_clusters])
      attr_reader :weights

      # Return the mean of each cluster.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :means

      # Return the covariance of each cluster: the diagonal elements of the covariance matrix
      # if covariance_type is 'diag', or the full covariance matrix if it is 'full'.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features] if 'diag', [n_clusters, n_features, n_features] if 'full')
      attr_reader :covariances

      # Create a new cluster analyzer with gaussian mixture model.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      # @param covariance_type [String] The type of covariance parameter to be used ('diag' or 'full').
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param reg_covar [Float] The non-negative regularization to the diagonal of covariance.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag',
                     max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          init: (init == 'random' ? 'random' : 'k-means++'),
          covariance_type: (covariance_type == 'full' ? 'full' : 'diag'),
          max_iter: max_iter,
          tol: tol,
          reg_covar: reg_covar,
          random_seed: random_seed || srand
        }
      end

      # Analyze clusters in the given training data with the EM algorithm.
      #
      # @overload fit(x) -> GaussianMixture
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [GaussianMixture] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_enable_linalg('fit')
        x = ::Rumale::Validation.check_convert_sample_array(x)

        n_samples = x.shape[0]
        memberships = init_memberships(x)
        @params[:max_iter].times do |t|
          @n_iter = t
          # M-step: re-estimate the parameters from the current memberships.
          @weights = calc_weights(n_samples, memberships)
          @means = calc_means(x, memberships)
          @covariances = calc_covariances(x, @means, memberships, @params[:reg_covar], @params[:covariance_type])
          # E-step: recompute the memberships from the new parameters.
          new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
          error = (memberships - new_memberships).abs.max
          break if error <= @params[:tol]

          # new_memberships is freshly allocated and never mutated, so no copy is needed.
          memberships = new_memberships
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        check_enable_linalg('predict')
        x = ::Rumale::Validation.check_convert_sample_array(x)

        memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
        assign_cluster(memberships)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_enable_linalg('fit_predict')
        x = ::Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Index of the largest membership in each row.
      # max_index(axis: 1) returns flat indices, so the row offsets are subtracted.
      def assign_cluster(memberships)
        n_clusters = memberships.shape[1]
        memberships.max_index(axis: 1) - Numo::Int32[*0.step(memberships.size - 1, n_clusters)]
      end

      # Initial hard (one-hot) memberships taken from a single k-means assignment step.
      def init_memberships(x)
        kmeans = ::Rumale::Clustering::KMeans.new(
          n_clusters: @params[:n_clusters], init: @params[:init], max_iter: 0, random_seed: @params[:random_seed]
        )
        cluster_ids = kmeans.fit_predict(x)
        Numo::DFloat.cast(::Rumale::Utils.binarize_labels(cluster_ids))
      end

      # Posterior probability of each cluster for each sample (rows sum to one).
      # Computed in the log domain and normalized with the log-sum-exp trick. This keeps
      # high-dimensional data, or samples far from every cluster, from underflowing to 0/0 = NaN.
      def calc_memberships(x, weights, means, covars, covar_type)
        n_samples = x.shape[0]
        n_clusters = means.shape[0]
        log_memberships = Numo::DFloat.zeros(n_samples, n_clusters)
        n_clusters.times do |n|
          centered = x - means[n, true]
          covar = covar_type == 'full' ? covars[n, true, true] : covars[n, true]
          log_memberships[true, n] = calc_log_unnormalized_membership(centered, weights[n], covar, covar_type)
        end
        # Shifting each row by its maximum does not change the normalized result, but it avoids exp underflow.
        log_memberships -= log_memberships.max(axis: 1).expand_dims(1)
        memberships = Numo::NMath.exp(log_memberships)
        memberships / memberships.sum(axis: 1).expand_dims(1)
      end

      # Mixing weight of each cluster: its share of the total membership.
      def calc_weights(n_samples, memberships)
        memberships.sum(axis: 0) / n_samples
      end

      # Membership-weighted mean of the samples for each cluster.
      def calc_means(x, memberships)
        memberships.transpose.dot(x) / memberships.sum(axis: 0).expand_dims(1)
      end

      # Dispatch to the full- or diagonal-covariance estimator.
      def calc_covariances(x, means, memberships, reg_covar, covar_type)
        if covar_type == 'full'
          calc_full_covariances(x, means, reg_covar, memberships)
        else
          calc_diag_covariances(x, means, reg_covar, memberships)
        end
      end

      # Membership-weighted per-feature variances plus reg_covar, one row per cluster.
      def calc_diag_covariances(x, means, reg_covar, memberships)
        n_clusters = means.shape[0]
        diag_cov = Array.new(n_clusters) do |n|
          centered = x - means[n, true]
          memberships[true, n].dot(centered**2) / memberships[true, n].sum
        end
        Numo::DFloat.asarray(diag_cov) + reg_covar
      end

      # Membership-weighted covariance matrices plus reg_covar on the diagonal, one per cluster.
      def calc_full_covariances(x, means, reg_covar, memberships)
        n_features = x.shape[1]
        n_clusters = means.shape[0]
        cov_mats = Numo::DFloat.zeros(n_clusters, n_features, n_features)
        reg_mat = Numo::DFloat.eye(n_features) * reg_covar
        n_clusters.times do |n|
          centered = x - means[n, true]
          members = memberships[true, n]
          cov_mats[n, true, true] = reg_mat + (centered.transpose * members).dot(centered) / members.sum
        end
        cov_mats
      end

      # Per-sample log of weight * |covar|^(-1/2) * exp(-0.5 * Mahalanobis distance).
      # The (2 * pi)^(-n_features / 2) factor is shared by every cluster and cancels in normalization.
      def calc_log_unnormalized_membership(centered, weight, covar, covar_type)
        inv_covar = calc_inv_covariance(covar, covar_type)
        distances = if covar_type == 'full'
                      (centered.dot(inv_covar) * centered).sum(axis: 1)
                    else
                      (centered * inv_covar * centered).sum(axis: 1)
                    end
        (-0.5 * distances) + (Math.log(weight) + calc_log_inv_sqrt_det_covariance(covar, covar_type))
      end

      # Inverse of the covariance: a matrix inverse for 'full', an element-wise reciprocal for 'diag'.
      def calc_inv_covariance(covar, covar_type)
        if covar_type == 'full'
          Numo::Linalg.inv(covar)
        else
          1.0 / covar
        end
      end

      # log(|covar|^(-1/2)). For 'diag' it sums the log variances instead of taking their product,
      # which would underflow or overflow when there are many features.
      def calc_log_inv_sqrt_det_covariance(covar, covar_type)
        if covar_type == 'full'
          -0.5 * Math.log(Numo::Linalg.det(covar))
        else
          -0.5 * Numo::NMath.log(covar).sum
        end
      end

      # Raise when covariance_type is 'full' but Numo::Linalg is not available.
      def check_enable_linalg(method_name)
        return unless @params[:covariance_type] == 'full' && !enable_linalg?

        raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
      end
    end
  end
end