rumale-clustering 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,289 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+ require 'rumale/clustering/single_linkage'
8
+
9
module Rumale
  module Clustering
    # HDBSCAN is a class that implements HDBSCAN cluster analysis.
    #
    # @example
    #   require 'rumale/clustering/hdbscan'
    #
    #   analyzer = Rumale::Clustering::HDBSCAN.new(min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Campello, R J. G. B., Moulavi, D., Zimek, A., and Sander, J., "Hierarchical Density Estimates for Data Clustering, Visualization, and Outlier Detection," TKDD, Vol. 10 (1), pp. 5:1--5:51, 2015.
    # - Campello, R J. G. B., Moulavi, D., and Sander, J., "Density-Based Clustering Based on Hierarchical Density Estimates," Proc. PAKDD'13, pp. 160--172, 2013.
    # - Lelis, L., and Sander, J., "Semi-Supervised Density-Based Clustering," Proc. ICDM'09, pp. 842--847, 2009.
    class HDBSCAN < ::Rumale::Base::Estimator # rubocop:disable Metrics/ClassLength
      include ::Rumale::Base::ClusterAnalyzer

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with HDBSCAN algorithm.
      #
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      # @param min_cluster_size [Integer/Nil] The minimum size of cluster. If nil is given, it is set equal to min_samples.
      # @param metric [String] The metric to calculate the distances.
      #   If metric is 'euclidean', Euclidean distance is calculated for distance between points.
      #   If metric is 'precomputed', the fit and fit_predict methods expect to be given a distance matrix.
      #   Any other value is treated as 'euclidean'.
      def initialize(min_samples: 10, min_cluster_size: nil, metric: 'euclidean')
        super()
        @params = {
          min_samples: min_samples,
          min_cluster_size: min_cluster_size || min_samples,
          metric: (metric == 'precomputed' ? 'precomputed' : 'euclidean')
        }
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> HDBSCAN
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #     If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      #   @return [HDBSCAN] The learned cluster analyzer itself.
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      def fit(x, _y = nil)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        raise ArgumentError, 'the input distance matrix should be square' if check_invalid_array_shape(x)

        fit_predict(x)
        self
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
      #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample (-1 for noise).
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      def fit_predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        raise ArgumentError, 'the input distance matrix should be square' if check_invalid_array_shape(x)

        distance_mat = @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x)
        @labels = partial_fit(distance_mat)
      end

      private

      # True when a precomputed distance matrix is required but x is not square.
      def check_invalid_array_shape(x)
        @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
      end

      # @!visibility private
      # Disjoint-set forest with union by rank and path compression.
      class UnionFind
        def initialize(n)
          @parent = Numo::Int32.new(n).seq
          @rank = Numo::Int32.zeros(n)
        end

        # @!visibility private
        def union(x, y)
          x_root = find(x)
          y_root = find(y)

          return if x_root == y_root

          if @rank[x_root] < @rank[y_root]
            @parent[x_root] = y_root
          else
            @parent[y_root] = x_root
            @rank[x_root] += 1 if @rank[x_root] == @rank[y_root]
          end

          nil
        end

        # @!visibility private
        def find(x)
          @parent[x] = find(@parent[x]) if @parent[x] != x
          @parent[x]
        end
      end

      # @!visibility private
      # An edge of the condensed tree: parent x, child y, density (lambda) weight and child size.
      class Node
        # @!visibility private
        attr_reader :x, :y, :weight, :n_elements

        # @!visibility private
        def initialize(x:, y:, weight:, n_elements: 0)
          @x = x
          @y = y
          @weight = weight
          @n_elements = n_elements
        end

        # @!visibility private
        def ==(other)
          x == other.x && y == other.y && weight == other.weight && n_elements == other.n_elements
        end
      end

      private_constant :UnionFind, :Node

      # Runs the HDBSCAN pipeline on a distance matrix and returns the flat cluster labels.
      def partial_fit(distance_mat)
        mr_distance_mat = mutual_reachability_distances(distance_mat, @params[:min_samples])
        hierarchy = ::Rumale::Clustering::SingleLinkage.new(n_clusters: 1, metric: 'precomputed').fit(mr_distance_mat).hierarchy
        tree = condense_tree(hierarchy, @params[:min_cluster_size])
        stabilities = cluster_stability(tree)
        flatten(tree, stabilities)
      end

      # Mutual reachability distance: max(core(i), core(j), d(i, j)).
      # The core distance of a point is read from column (min_samples + 1) of its sorted distance row.
      def mutual_reachability_distances(distance_mat, min_samples)
        core_distances = distance_mat.sort(axis: 1)[true, min_samples + 1]
        Numo::DFloat.maximum(core_distances.expand_dims(1), Numo::DFloat.maximum(core_distances, distance_mat))
      end

      # Lists the node ids below root in the single-linkage hierarchy, level by level.
      # Ids below n_points are samples; ids >= n_points refer to hierarchy[id - n_points].
      def breadth_first_search_hierarchy(hierarchy, root)
        n_edges = hierarchy.size
        n_points = n_edges + 1
        to_process = [root]
        res = []
        while to_process.any?
          res.concat(to_process)
          to_process = to_process.select { |n| n >= n_points }.map { |n| n - n_points }
          to_process = to_process.map { |n| [hierarchy[n].x, hierarchy[n].y] }.flatten if to_process.any?
        end
        res
      end

      # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity

      # Builds the condensed tree: splits where both sides hold at least min_cluster_size points create
      # two new clusters; smaller sides are emitted as individual points falling out of the parent cluster.
      def condense_tree(hierarchy, min_cluster_size)
        n_edges = hierarchy.size
        root = 2 * n_edges
        n_points = n_edges + 1
        next_label = n_points + 1

        node_ids = breadth_first_search_hierarchy(hierarchy, root)

        relabel = Numo::Int32.zeros(root + 1)
        relabel[root] = n_points
        res = []
        visited = {}

        node_ids.each do |n_id|
          next if visited[n_id] || n_id < n_points

          edge = hierarchy[n_id - n_points]

          density = edge.weight > 0.0 ? 1.fdiv(edge.weight) : Float::INFINITY
          n_x_elements = edge.x >= n_points ? hierarchy[edge.x - n_points].n_elements : 1
          n_y_elements = edge.y >= n_points ? hierarchy[edge.y - n_points].n_elements : 1

          if n_x_elements >= min_cluster_size && n_y_elements >= min_cluster_size
            relabel[edge.x] = next_label
            res.push(Node.new(x: relabel[n_id], y: relabel[edge.x], weight: density, n_elements: n_x_elements))
            next_label += 1
            relabel[edge.y] = next_label
            res.push(Node.new(x: relabel[n_id], y: relabel[edge.y], weight: density, n_elements: n_y_elements))
            next_label += 1
          elsif n_x_elements < min_cluster_size && n_y_elements < min_cluster_size
            breadth_first_search_hierarchy(hierarchy, edge.x).each do |sn_id|
              res.push(Node.new(x: relabel[n_id], y: sn_id, weight: density, n_elements: 1)) if sn_id < n_points
              visited[sn_id] = true
            end
            breadth_first_search_hierarchy(hierarchy, edge.y).each do |sn_id|
              res.push(Node.new(x: relabel[n_id], y: sn_id, weight: density, n_elements: 1)) if sn_id < n_points
              visited[sn_id] = true
            end
          elsif n_x_elements < min_cluster_size
            # The large side (y) inherits the parent's cluster label.
            relabel[edge.y] = relabel[n_id]
            breadth_first_search_hierarchy(hierarchy, edge.x).each do |sn_id|
              res.push(Node.new(x: relabel[n_id], y: sn_id, weight: density, n_elements: 1)) if sn_id < n_points
              visited[sn_id] = true
            end
          elsif n_y_elements < min_cluster_size
            # The large side (x) inherits the parent's cluster label.
            relabel[edge.x] = relabel[n_id]
            breadth_first_search_hierarchy(hierarchy, edge.y).each do |sn_id|
              res.push(Node.new(x: relabel[n_id], y: sn_id, weight: density, n_elements: 1)) if sn_id < n_points
              visited[sn_id] = true
            end
          end
        end
        res
      end

      # Computes the stability of every cluster (parent id) in the condensed tree:
      # sum over its outgoing edges of (edge density - birth density of the cluster) * edge size.
      # @return [Hash] cluster id => stability
      def cluster_stability(tree)
        # Sorted in place on purpose: flatten runs union-by-rank over this same array, and the
        # union-find roots it reads back depend on the order in which edges are merged.
        tree.sort! { |a, b| a.weight <=> b.weight }

        root = tree.map(&:x).min
        child_max = tree.map(&:y).max
        child_max = root if child_max < root

        # The birth density of a node is the weight of the edge in which it appears as the child.
        # (The previous scan only recorded a birth when the node later became the scan "cursor",
        # so a cluster whose split edges immediately followed its own birth edge left both of its
        # children at Infinity, giving them -Infinity stability.)
        densities = Numo::DFloat.zeros(child_max + 1) + Float::INFINITY
        tree.each { |edge| densities[edge.y] = edge.weight if edge.weight < densities[edge.y] }
        densities[root] = 0.0

        tree.each_with_object({}) do |edge, stab|
          stab[edge.x] ||= 0.0
          stab[edge.x] += (edge.weight - densities[edge.x]) * edge.n_elements
        end
      end

      # Lists root and all of its descendants in tree, level by level.
      def breadth_first_search_tree(tree, root)
        to_process = [root]
        res = []
        while to_process.any?
          res.concat(to_process)
          to_process = tree.select { |v| to_process.include?(v.x) }.map(&:y)
        end
        res
      end

      # Selects the clusters maximizing total stability (excess of mass) and labels each sample
      # with the index of its selected cluster, or -1 when it belongs to none.
      def flatten(tree, stabilities)
        # Every cluster id except the smallest one (the root), visited from the leaves upwards.
        node_ids = stabilities.keys.sort.reverse.slice(0, stabilities.size - 1)

        cluster_tree = tree.select { |edge| edge.n_elements > 1 }
        is_cluster = node_ids.each_with_object({}) { |n_id, h| h[n_id] = true }

        node_ids.each do |n_id|
          children = cluster_tree.select { |node| node.x == n_id }.map(&:y)
          subtree_stability = children.inject(0.0) { |sum, c_id| sum + stabilities[c_id] }
          if subtree_stability > stabilities[n_id]
            is_cluster[n_id] = false
            stabilities[n_id] = subtree_stability
          else
            breadth_first_search_tree(cluster_tree, n_id).each do |sn_id|
              is_cluster[sn_id] = false if sn_id != n_id
            end
          end
        end

        selected_node_ids = is_cluster.select { |_k, v| v }.keys.sort
        cluster_label_map = selected_node_ids.each_with_object({}).with_index { |(n_idx, h), c_idx| h[n_idx] = c_idx }

        parent_arr = tree.map(&:x)
        uf = UnionFind.new(parent_arr.max + 1)
        tree.each { |edge| uf.union(edge.x, edge.y) if cluster_label_map[edge.y].nil? }

        root = parent_arr.min
        res = Numo::Int32.zeros(root)
        root.times do |n|
          cluster = uf.find(n)
          res[n] = cluster < root ? -1 : cluster_label_map[cluster] || -1
        end
        res
      end
      # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
    end
  end
end
@@ -0,0 +1,120 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+
8
module Rumale
  module Clustering
    # KMeans is a class that implements K-Means cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   require 'rumale/clustering/k_means'
    #
    #   analyzer = Rumale::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMeans < ::Rumale::Base::Estimator
      include ::Rumale::Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          init: (init == 'random' ? 'random' : 'k-means++'),
          max_iter: max_iter,
          tol: tol,
          random_seed: (random_seed || srand)
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMeans
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   @return [KMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        init_cluster_centers(x)
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = cluster_labels.eq(n)
            # A centroid with no assigned samples keeps its previous position.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Stop once the mean centroid displacement is within the tolerance.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Returns the index of the nearest centroid for each sample.
      def assign_cluster(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index(axis: 1) yields flat indices; subtracting each row's start offset converts them to column indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Initializes @cluster_centers with randomly sampled rows, then (for 'k-means++') replaces
      # rows 1..n_clusters-1 by D^2-weighted sampling. A copy of @rng is used so that repeated
      # fits with the same seed start from the same centers.
      def init_cluster_centers(x)
        # random initialize
        n_samples = x.shape[0]
        sub_rng = @rng.dup
        rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'

        # k-means++ initialize
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          probs = min_distances**2 / (min_distances**2).sum
          cum_probs = probs.cumsum
          selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
          # No index is found when every sample coincides with an already chosen center
          # (0/0 yields NaN probabilities) or when rounding leaves the cumulative sum just
          # below the drawn value; fall back to the farthest sample instead of indexing with nil.
          selected_id ||= min_distances.max_index
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end
    end
  end
end
@@ -0,0 +1,143 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
module Rumale
  module Clustering
    # KMedoids is a class that implements K-Medoids cluster analysis.
    #
    # @example
    #   require 'rumale/clustering/k_medoids'
    #
    #   analyzer = Rumale::Clustering::KMedoids.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMedoids < ::Rumale::Base::Estimator
      include ::Rumale::Base::ClusterAnalyzer

      # Return the indices of medoids.
      # @return [Numo::Int32] (shape: [n_clusters])
      attr_reader :medoid_ids

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Medoids method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param metric [String] The metric to calculate the distances.
      #   If metric is 'euclidean', Euclidean distance is calculated for distance between points.
      #   If metric is 'precomputed', the fit and fit_predict methods expect to be given a distance matrix.
      #   Any other value is treated as 'euclidean'.
      # @param init [String] The initialization method for medoids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, metric: 'euclidean', init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          metric: (metric == 'precomputed' ? 'precomputed' : 'euclidean'),
          init: (init == 'random' ? 'random' : 'k-means++'),
          max_iter: max_iter,
          tol: tol,
          random_seed: (random_seed || srand)
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMedoids
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #     If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      #   @return [KMedoids] The learned cluster analyzer itself.
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      def fit(x, _y = nil)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        raise ArgumentError, 'the input distance matrix should be square' if check_invalid_array_shape(x)

        # initialize some variables.
        distance_mat = @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x)
        init_cluster_centers(distance_mat)
        error = distance_mat[true, @medoid_ids].mean
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(distance_mat[true, @medoid_ids])
          @params[:n_clusters].times do |n|
            assigned_ids = cluster_labels.eq(n).where
            # A medoid can attract no samples when it ties with an earlier medoid (e.g. duplicate
            # points); keep it unchanged instead of indexing into an empty selection.
            next if assigned_ids.size.zero?

            # The new medoid is the member with the smallest total distance to the other members.
            @medoid_ids[n] = assigned_ids[distance_mat[assigned_ids, assigned_ids].sum(axis: 1).min_index]
          end
          new_error = distance_mat[true, @medoid_ids].mean
          break if (error - new_error).abs <= @params[:tol]

          error = new_error
        end
        @cluster_centers = x[@medoid_ids, true].dup if @params[:metric] == 'euclidean'
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      #   If the metric is 'precomputed', x must be distances between samples and medoids (shape: [n_samples, n_clusters]).
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      # @raise [ArgumentError] If the metric is 'precomputed' and x does not have one column per medoid.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        distance_mat = @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
        if @params[:metric] == 'precomputed' && distance_mat.shape[1] != @medoid_ids.size
          raise ArgumentError, 'the shape of input matrix should be n_samples-by-n_clusters'
        end

        assign_cluster(distance_mat)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      # @raise [ArgumentError] If the metric is 'precomputed' and x is not square.
      def fit_predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        raise ArgumentError, 'the input distance matrix should be square' if check_invalid_array_shape(x)

        fit(x)
        if @params[:metric] == 'precomputed'
          predict(x[true, @medoid_ids])
        else
          predict(x)
        end
      end

      private

      # True when a precomputed distance matrix is required but x is not square.
      def check_invalid_array_shape(x)
        @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
      end

      # Returns, for each row, the column index of the nearest medoid.
      def assign_cluster(distances_to_medoids)
        # min_index(axis: 1) yields flat indices; subtracting each row's start offset converts them to column indices.
        distances_to_medoids.min_index(axis: 1) - Numo::Int32[*0.step(distances_to_medoids.size - 1, @params[:n_clusters])]
      end

      # Initializes @medoid_ids with randomly sampled indices, then (for 'k-means++') replaces
      # entries 1..n_clusters-1 by D^2-weighted sampling. A copy of @rng is used so that repeated
      # fits with the same seed start from the same medoids.
      def init_cluster_centers(distance_mat)
        # random initialize
        n_samples = distance_mat.shape[0]
        sub_rng = @rng.dup
        @medoid_ids = Numo::Int32.asarray(Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng))
        return unless @params[:init] == 'k-means++'

        # k-means++ initialize
        (1...@params[:n_clusters]).each do |n|
          distances = distance_mat[true, @medoid_ids[0...n]]
          min_distances = distances.flatten[distances.min_index(axis: 1)]
          probs = min_distances**2 / (min_distances**2).sum
          cum_probs = probs.cumsum
          selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
          # No index is found when every sample is at zero distance from a chosen medoid
          # (0/0 yields NaN probabilities) or when rounding leaves the cumulative sum just
          # below the drawn value; fall back to the farthest sample instead of storing nil.
          @medoid_ids[n] = selected_id || min_distances.max_index
        end
      end
    end
  end
end
@@ -0,0 +1,138 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+
8
module Rumale
  module Clustering
    # MiniBatchKMeans is a class that implements K-Means cluster analysis
    # with mini-batch stochastic gradient descent (SGD).
    #
    # @example
    #   require 'rumale/clustering/mini_batch_k_means'
    #
    #   analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
    class MiniBatchKMeans < ::Rumale::Base::Estimator
      include ::Rumale::Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method with mini-batch SGD.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param batch_size [Integer] The size of the mini batches.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          init: (init == 'random' ? 'random' : 'k-means++'),
          max_iter: max_iter,
          batch_size: batch_size,
          tol: tol,
          random_seed: (random_seed || srand)
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> MiniBatchKMeans
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   @return [MiniBatchKMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        # initialization.
        n_samples = x.shape[0]
        update_counter = Numo::Int32.zeros(@params[:n_clusters])
        sub_rng = @rng.dup
        init_cluster_centers(x, sub_rng)
        # optimization with mini-batch sgd.
        @params[:max_iter].times do |_t|
          sample_ids = Array(0...n_samples).shuffle(random: sub_rng)
          old_centers = @cluster_centers.dup
          until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
            # sub sampling
            sub_x = x[subset_ids, true]
            # assign nearest centroids
            cluster_labels = assign_cluster(sub_x)
            # update centroids with a per-centroid learning rate of 1 / (number of updates so far)
            @params[:n_clusters].times do |c|
              assigned_bits = cluster_labels.eq(c)
              next unless assigned_bits.count.positive?

              update_counter[c] += 1
              learning_rate = 1.fdiv(update_counter[c])
              update = sub_x[assigned_bits.where, true].mean(axis: 0)
              @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
            end
          end
          # Stop once the mean centroid displacement over one epoch is within the tolerance.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Returns the index of the nearest centroid for each sample.
      def assign_cluster(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index(axis: 1) yields flat indices; subtracting each row's start offset converts them to column indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Initializes @cluster_centers with randomly sampled rows, then (for 'k-means++') replaces
      # rows 1..n_clusters-1 by D^2-weighted sampling, drawing from the shared sub_rng.
      def init_cluster_centers(x, sub_rng)
        # random initialize
        n_samples = x.shape[0]
        rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'

        # k-means++ initialize
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          probs = min_distances**2 / (min_distances**2).sum
          cum_probs = probs.cumsum
          selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
          # No index is found when every sample coincides with an already chosen center
          # (0/0 yields NaN probabilities) or when rounding leaves the cumulative sum just
          # below the drawn value; fall back to the farthest sample instead of indexing with nil.
          selected_id ||= min_distances.max_index
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end
    end
  end
end