rumale-clustering 0.25.0 → 0.27.0
Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 45670ff85fcced403ac58edd979f2487027f6b61ca1c8b9d95eddf4e292e2f5d
|
4
|
+
data.tar.gz: 2a1f7a44e0c69555f43f3b24e487471aabd85688bf21a5d141e936baee9c27c7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: a881089ed98394ca1ebe9f067956d9f36b33ce38b160cb95220b1fac73295eb96bd02c517acf87b49ee637ef6cc82695ac3aa324e6540b84ccda2a8cb2fb7bd4
|
7
|
+
data.tar.gz: cf6afa202648c7aed4e4194d75a13977187c5040cdfb8918155837353658e217ac6d58d58d8ea7a3b6df0d53f65fd1d732c6e1f49cd2c1a91c0ae30fe0af2c32
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/validation'
|
7
|
+
|
8
|
+
module Rumale
|
9
|
+
module Clustering
|
10
|
+
# MeanShift is a class that implements mean-shift clustering with a flat kernel.
|
11
|
+
#
|
12
|
+
# @example
|
13
|
+
# require 'rumale/clustering/mean_shift'
|
14
|
+
#
|
15
|
+
# analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
|
16
|
+
# cluster_labels = analyzer.fit_predict(samples)
|
17
|
+
#
|
18
|
+
# *Reference*
|
19
|
+
# - Carreira-Perpinan, M A., "A review of mean-shift algorithms for clustering," arXiv:1503.00687v1.
|
20
|
+
# - Sheikh, Y A., Khan, E A., and Kanade, T., "Mode-seeking by Medoidshifts," Proc. ICCV'07, pp. 1--8, 2007.
|
21
|
+
# - Vedaldi, A., and Soatto, S., "Quick Shift and Kernel Methods for Mode Seeking," Proc. ECCV'08, pp. 705--718, 2008.
|
22
|
+
class MeanShift < Rumale::Base::Estimator
|
23
|
+
include Rumale::Base::ClusterAnalyzer
|
24
|
+
|
25
|
+
# Return the centroids.
|
26
|
+
# @return [Numo::DFloat] (shape: [n_clusters, n_features])
|
27
|
+
attr_reader :cluster_centers
|
28
|
+
|
29
|
+
# Create a new cluster analyzer with mean-shift algorithm.
|
30
|
+
#
|
31
|
+
# @param bandwidth [Float] The bandwidth parameter of the flat kernel.
|
32
|
+
# @param max_iter [Integer] The maximum number of iterations.
|
33
|
+
# @param tol [Float] The tolerance of the termination criterion.
|
34
|
+
def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
|
35
|
+
super()
|
36
|
+
@params = {
|
37
|
+
bandwidth: bandwidth,
|
38
|
+
max_iter: max_iter,
|
39
|
+
tol: tol
|
40
|
+
}
|
41
|
+
end
|
42
|
+
|
43
|
+
# Analyze clusters with given training data.
|
44
|
+
#
|
45
|
+
# @overload fit(x) -> MeanShift
|
46
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
47
|
+
# @return [MeanShift] The learned cluster analyzer itself.
|
48
|
+
def fit(x, _y = nil)
|
49
|
+
x = Rumale::Validation.check_convert_sample_array(x)
|
50
|
+
|
51
|
+
z = x.dup
|
52
|
+
@params[:max_iter].times do
|
53
|
+
distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
|
54
|
+
kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
|
55
|
+
sum_kernel = kernel_mat.sum(axis: 0)
|
56
|
+
weight_mat = kernel_mat.dot((1 / sum_kernel).diag)
|
57
|
+
updated = weight_mat.transpose.dot(x)
|
58
|
+
break if (z - updated).abs.sum(axis: 1).max <= @params[:tol]
|
59
|
+
|
60
|
+
z = updated
|
61
|
+
end
|
62
|
+
|
63
|
+
@cluster_centers = connect_components(z)
|
64
|
+
|
65
|
+
self
|
66
|
+
end
|
67
|
+
|
68
|
+
# Predict cluster labels for samples.
|
69
|
+
#
|
70
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
|
71
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
72
|
+
def predict(x)
|
73
|
+
x = Rumale::Validation.check_convert_sample_array(x)
|
74
|
+
|
75
|
+
assign_cluster(x)
|
76
|
+
end
|
77
|
+
|
78
|
+
# Analyze clusters and assign samples to clusters.
|
79
|
+
#
|
80
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
81
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
82
|
+
def fit_predict(x)
|
83
|
+
x = Rumale::Validation.check_convert_sample_array(x)
|
84
|
+
|
85
|
+
fit(x).predict(x)
|
86
|
+
end
|
87
|
+
|
88
|
+
private
|
89
|
+
|
90
|
+
def assign_cluster(x)
|
91
|
+
n_clusters = @cluster_centers.shape[0]
|
92
|
+
distance_mat = Rumale::PairwiseMetric.squared_error(x, @cluster_centers)
|
93
|
+
distance_mat.min_index(axis: 1) - Numo::Int32[*0.step(distance_mat.size - 1, n_clusters)]
|
94
|
+
end
|
95
|
+
|
96
|
+
def connect_components(z)
|
97
|
+
centers = []
|
98
|
+
n_samples = z.shape[0]
|
99
|
+
|
100
|
+
n_samples.times do |idx|
|
101
|
+
assigned = false
|
102
|
+
centers.each do |cluster_vec|
|
103
|
+
dist = Math.sqrt(((z[idx, true] - cluster_vec)**2).sum.abs)
|
104
|
+
if dist <= @params[:bandwidth]
|
105
|
+
assigned = true
|
106
|
+
break
|
107
|
+
end
|
108
|
+
end
|
109
|
+
centers << z[idx, true].dup unless assigned
|
110
|
+
end
|
111
|
+
|
112
|
+
Numo::DFloat.asarray(centers)
|
113
|
+
end
|
114
|
+
end
|
115
|
+
end
|
116
|
+
end
|
data/lib/rumale/clustering.rb
CHANGED
@@ -7,6 +7,7 @@ require_relative 'clustering/gaussian_mixture'
|
|
7
7
|
require_relative 'clustering/hdbscan'
|
8
8
|
require_relative 'clustering/k_means'
|
9
9
|
require_relative 'clustering/k_medoids'
|
10
|
+
require_relative 'clustering/mean_shift'
|
10
11
|
require_relative 'clustering/mini_batch_k_means'
|
11
12
|
require_relative 'clustering/power_iteration'
|
12
13
|
require_relative 'clustering/single_linkage'
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-clustering
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.25.0
|
4
|
+
version: 0.27.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - "~>"
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.
|
33
|
+
version: 0.27.0
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - "~>"
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.
|
40
|
+
version: 0.27.0
|
41
41
|
description: |
|
42
42
|
Rumale::Clustering provides cluster analysis algorithms,
|
43
43
|
such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
|
@@ -56,6 +56,7 @@ files:
|
|
56
56
|
- lib/rumale/clustering/hdbscan.rb
|
57
57
|
- lib/rumale/clustering/k_means.rb
|
58
58
|
- lib/rumale/clustering/k_medoids.rb
|
59
|
+
- lib/rumale/clustering/mean_shift.rb
|
59
60
|
- lib/rumale/clustering/mini_batch_k_means.rb
|
60
61
|
- lib/rumale/clustering/power_iteration.rb
|
61
62
|
- lib/rumale/clustering/single_linkage.rb
|