rumale-clustering 0.25.0 → 0.26.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rumale/clustering/mean_shift.rb +116 -0
- data/lib/rumale/clustering/version.rb +1 -1
- data/lib/rumale/clustering.rb +1 -0
- metadata +5 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d4bd3dd2f04f44e145f7ebe90154dfd9e2970485c58cd68af82abd0f744e78c0
|
4
|
+
data.tar.gz: 9af9eafaa42f596f75c09a13b6da658bf7ceb69af657a22bc50f6c58b7a2eb05
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 2bcbe3e94d4ae65507fb6253b68264dcc42ed455a7150d7096ea71d4d838dae2de2ba2b7991e06b258b16e28f701f5c350b5e26a7c06f43a0b08d6e0145c5cb1
|
7
|
+
data.tar.gz: c08a182fd31b16aaad51186c4dfe3457e5ea7d5f9b956a966ead434a036d9817189cd3e5ee7be824a04c06baf60771423f9ef23c0488aceb1705dfde7aaeb8a4
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/validation'
|
7
|
+
|
8
|
+
module Rumale
  module Clustering
    # MeanShift is a class that implements mean-shift clustering with flat kernel.
    #
    # @example
    #   require 'rumale/clustering/mean_shift'
    #
    #   analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Carreira-Perpinan, M A., "A review of mean-shift algorithms for clustering," arXiv:1503.00687v1.
    # - Sheikh, Y A., Khan, E A., and Kanade, T., "Mode-seeking by Medoidshifts," Proc. ICCV'07, pp. 1--8, 2007.
    # - Vedaldi, A., and Soatto, S., "Quick Shift and Kernel Methods for Mode Seeking," Proc. ECCV'08, pp. 705--718, 2008.
    class MeanShift < Rumale::Base::Estimator
      include Rumale::Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Create a new cluster analyzer with mean-shift algorithm.
      #
      # @param bandwidth [Float] The bandwidth parameter of flat kernel.
      #   It is also used as the distance threshold for merging converged points into one cluster center.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
        super()
        @params = {
          bandwidth: bandwidth,
          max_iter: max_iter,
          tol: tol
        }
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> MeanShift
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   @return [MeanShift] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = Rumale::Validation.check_convert_sample_array(x)

        # Every training sample starts as its own seed point and is shifted iteratively.
        z = x.dup
        @params[:max_iter].times do
          # distance_mat[i, j] is the distance between training sample i and seed point j.
          distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
          # Flat kernel: sample i contributes to seed point j only if it lies within the bandwidth.
          kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
          sum_kernel = kernel_mat.sum(axis: 0)
          # Scale column j by 1 / sum_kernel[j] via broadcasting. Multiplying by a diagonal
          # matrix gives the same values but costs O(n^3) time and an extra n x n allocation.
          weight_mat = kernel_mat / sum_kernel.expand_dims(0)
          # Each seed point moves to the mean of the training samples inside its window.
          updated = weight_mat.transpose.dot(x)
          max_shift = (z - updated).abs.sum(axis: 1).max
          # Keep the newest estimate even on the iteration that satisfies the tolerance.
          z = updated
          break if max_shift <= @params[:tol]
        end

        @cluster_centers = connect_components(z)

        self
      end

      # Predict cluster labels for samples.
      #
      # Each sample is labeled with the index of its nearest cluster center (squared Euclidean distance).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Return, for each row of x, the index of the nearest row of @cluster_centers.
      def assign_cluster(x)
        n_clusters = @cluster_centers.shape[0]
        distance_mat = Rumale::PairwiseMetric.squared_error(x, @cluster_centers)
        # min_index yields flat indices into distance_mat; subtracting each row's start
        # offset converts them into column (cluster) indices.
        distance_mat.min_index(axis: 1) - Numo::Int32[*0.step(distance_mat.size - 1, n_clusters)]
      end

      # Greedily collapse converged points into cluster centers: a point becomes a new center
      # unless it lies within the bandwidth of a center collected earlier (in row order).
      def connect_components(z)
        bandwidth = @params[:bandwidth]
        centers = []

        z.shape[0].times do |idx|
          point = z[idx, true]
          next if centers.any? { |center| Math.sqrt(((point - center)**2).sum) <= bandwidth }

          centers << point.dup
        end

        Numo::DFloat.asarray(centers)
      end
    end
  end
end
|
data/lib/rumale/clustering.rb
CHANGED
@@ -7,6 +7,7 @@ require_relative 'clustering/gaussian_mixture'
|
|
7
7
|
require_relative 'clustering/hdbscan'
|
8
8
|
require_relative 'clustering/k_means'
|
9
9
|
require_relative 'clustering/k_medoids'
|
10
|
+
require_relative 'clustering/mean_shift'
|
10
11
|
require_relative 'clustering/mini_batch_k_means'
|
11
12
|
require_relative 'clustering/power_iteration'
|
12
13
|
require_relative 'clustering/single_linkage'
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-clustering
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.25.0
|
4
|
+
version: 0.26.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-02-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - "~>"
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.
|
33
|
+
version: 0.26.0
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - "~>"
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.
|
40
|
+
version: 0.26.0
|
41
41
|
description: |
|
42
42
|
Rumale::Clustering provides cluster analysis algorithms,
|
43
43
|
such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
|
@@ -56,6 +56,7 @@ files:
|
|
56
56
|
- lib/rumale/clustering/hdbscan.rb
|
57
57
|
- lib/rumale/clustering/k_means.rb
|
58
58
|
- lib/rumale/clustering/k_medoids.rb
|
59
|
+
- lib/rumale/clustering/mean_shift.rb
|
59
60
|
- lib/rumale/clustering/mini_batch_k_means.rb
|
60
61
|
- lib/rumale/clustering/power_iteration.rb
|
61
62
|
- lib/rumale/clustering/single_linkage.rb
|