rumale-clustering 0.25.0 → 0.26.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 6ee5c0e4616d465c45562131bc0b0fd2c88b7e0725c45e9ef2864c271fec9f50
4
- data.tar.gz: c8261c272928d2cdf2ae7c1938373c899b10dc15ee5bdd3cc6e28b0ca32a7040
3
+ metadata.gz: d4bd3dd2f04f44e145f7ebe90154dfd9e2970485c58cd68af82abd0f744e78c0
4
+ data.tar.gz: 9af9eafaa42f596f75c09a13b6da658bf7ceb69af657a22bc50f6c58b7a2eb05
5
5
  SHA512:
6
- metadata.gz: 67d2413b7f74868c63df550611f753fb37bf1bab6b5c8ceaf024193083d5b767b38ffabee22b260ce18dfeac417f9c4ccd659a93f5f092db0070d388911a540a
7
- data.tar.gz: 5f89ecc7ae05d4c990bf90eba3e7a805b4c176c6217ac668690491d3c429ef5a3d028edbd307306cc65c0abd1b06b7b42436c01b4b6ac8d2ecdfd97099f84fab
6
+ metadata.gz: 2bcbe3e94d4ae65507fb6253b68264dcc42ed455a7150d7096ea71d4d838dae2de2ba2b7991e06b258b16e28f701f5c350b5e26a7c06f43a0b08d6e0145c5cb1
7
+ data.tar.gz: c08a182fd31b16aaad51186c4dfe3457e5ea7d5f9b956a966ead434a036d9817189cd3e5ee7be824a04c06baf60771423f9ef23c0488aceb1705dfde7aaeb8a4
@@ -0,0 +1,116 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+
8
module Rumale
  module Clustering
    # MeanShift is a class that implements mean-shift clustering with flat kernel.
    #
    # @example
    #   require 'rumale/clustering/mean_shift'
    #
    #   analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Carreira-Perpinan, M A., "A review of mean-shift algorithms for clustering," arXiv:1503.00687v1.
    # - Sheikh, Y A., Khan, E A., and Kanade, T., "Mode-seeking by Medoidshifts," Proc. ICCV'07, pp. 1--8, 2007.
    # - Vedaldi, A., and Soatto, S., "Quick Shift and Kernel Methods for Mode Seeking," Proc. ECCV'08, pp. 705--718, 2008.
    class MeanShift < Rumale::Base::Estimator
      include Rumale::Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Create a new cluster analyzer with mean-shift algorithm.
      #
      # @param bandwidth [Float] The bandwidth parameter of flat kernel.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
        super()
        @params = {
          bandwidth: bandwidth,
          max_iter: max_iter,
          tol: tol
        }
      end

      # Analyze clusters with given training data.
      #
      # Starting from a copy of the training samples, every point is repeatedly
      # replaced by the mean of the training samples that lie within `bandwidth`
      # of it (flat kernel). Iteration stops after `max_iter` rounds or as soon as
      # the largest per-point L1 shift is not greater than `tol`. The shifted
      # points are then merged into the cluster centers.
      #
      # @overload fit(x) -> MeanShift
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [MeanShift] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = Rumale::Validation.check_convert_sample_array(x)

        z = x.dup
        @params[:max_iter].times do
          distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
          # kernel_mat[i, j] is 1.0 when training sample i is within the bandwidth of shifted point j.
          kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
          sum_kernel = kernel_mat.sum(axis: 0)
          # Normalize each column by its neighbor count. Broadcasting over the last axis
          # gives the same weights as multiplying by diag(1 / sum_kernel), without
          # materializing an n x n diagonal matrix and a cubic-cost product per iteration.
          weight_mat = kernel_mat / sum_kernel
          updated = weight_mat.transpose.dot(x)
          break if (z - updated).abs.sum(axis: 1).max <= @params[:tol]

          z = updated
        end

        @cluster_centers = connect_components(z)

        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Label every sample with the index of its nearest cluster center.
      def assign_cluster(x)
        n_clusters = @cluster_centers.shape[0]
        distance_mat = Rumale::PairwiseMetric.squared_error(x, @cluster_centers)
        # min_index(axis: 1) yields positions in the flattened matrix, so the starting
        # offset of each row is subtracted to obtain the column (cluster) index.
        distance_mat.min_index(axis: 1) - Numo::Int32[*0.step(distance_mat.size - 1, n_clusters)]
      end

      # Greedily merge the shifted points into cluster centers: a point becomes a new
      # center unless it lies within the bandwidth of a center that was accepted earlier.
      def connect_components(z)
        centers = []

        z.shape[0].times do |idx|
          candidate = z[idx, true]
          merged = centers.any? do |center|
            Math.sqrt(((candidate - center)**2).sum) <= @params[:bandwidth]
          end
          centers << candidate.dup unless merged
        end

        Numo::DFloat.asarray(centers)
      end
    end
  end
end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of classes that implement cluster analysis methods.
6
6
  module Clustering
7
7
  # @!visibility private
8
- VERSION = '0.25.0'
8
+ VERSION = '0.26.0'
9
9
  end
10
10
  end
@@ -7,6 +7,7 @@ require_relative 'clustering/gaussian_mixture'
7
7
  require_relative 'clustering/hdbscan'
8
8
  require_relative 'clustering/k_means'
9
9
  require_relative 'clustering/k_medoids'
10
+ require_relative 'clustering/mean_shift'
10
11
  require_relative 'clustering/mini_batch_k_means'
11
12
  require_relative 'clustering/power_iteration'
12
13
  require_relative 'clustering/single_linkage'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-clustering
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.25.0
4
+ version: 0.26.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-01-18 00:00:00.000000000 Z
11
+ date: 2023-02-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: 0.25.0
33
+ version: 0.26.0
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: 0.25.0
40
+ version: 0.26.0
41
41
  description: |
42
42
  Rumale::Clustering provides cluster analysis algorithms,
43
43
  such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
@@ -56,6 +56,7 @@ files:
56
56
  - lib/rumale/clustering/hdbscan.rb
57
57
  - lib/rumale/clustering/k_means.rb
58
58
  - lib/rumale/clustering/k_medoids.rb
59
+ - lib/rumale/clustering/mean_shift.rb
59
60
  - lib/rumale/clustering/mini_batch_k_means.rb
60
61
  - lib/rumale/clustering/power_iteration.rb
61
62
  - lib/rumale/clustering/single_linkage.rb