rumale-clustering 0.25.0 → 0.26.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 6ee5c0e4616d465c45562131bc0b0fd2c88b7e0725c45e9ef2864c271fec9f50
4
- data.tar.gz: c8261c272928d2cdf2ae7c1938373c899b10dc15ee5bdd3cc6e28b0ca32a7040
3
+ metadata.gz: d4bd3dd2f04f44e145f7ebe90154dfd9e2970485c58cd68af82abd0f744e78c0
4
+ data.tar.gz: 9af9eafaa42f596f75c09a13b6da658bf7ceb69af657a22bc50f6c58b7a2eb05
5
5
  SHA512:
6
- metadata.gz: 67d2413b7f74868c63df550611f753fb37bf1bab6b5c8ceaf024193083d5b767b38ffabee22b260ce18dfeac417f9c4ccd659a93f5f092db0070d388911a540a
7
- data.tar.gz: 5f89ecc7ae05d4c990bf90eba3e7a805b4c176c6217ac668690491d3c429ef5a3d028edbd307306cc65c0abd1b06b7b42436c01b4b6ac8d2ecdfd97099f84fab
6
+ metadata.gz: 2bcbe3e94d4ae65507fb6253b68264dcc42ed455a7150d7096ea71d4d838dae2de2ba2b7991e06b258b16e28f701f5c350b5e26a7c06f43a0b08d6e0145c5cb1
7
+ data.tar.gz: c08a182fd31b16aaad51186c4dfe3457e5ea7d5f9b956a966ead434a036d9817189cd3e5ee7be824a04c06baf60771423f9ef23c0488aceb1705dfde7aaeb8a4
@@ -0,0 +1,116 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+
8
module Rumale
  module Clustering
    # MeanShift is a class that implements mean-shift clustering with flat kernel.
    #
    # @example
    #   require 'rumale/clustering/mean_shift'
    #
    #   analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - Carreira-Perpinan, M A., "A review of mean-shift algorithms for clustering," arXiv:1503.00687v1.
    # - Sheikh, Y A., Khan, E A., and Kanade, T., "Mode-seeking by Medoidshifts," Proc. ICCV'07, pp. 1--8, 2007.
    # - Vedaldi, A., and Soatto, S., "Quick Shift and Kernel Methods for Mode Seeking," Proc. ECCV'08, pp. 705--718, 2008.
    class MeanShift < Rumale::Base::Estimator
      include Rumale::Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Create a new cluster analyzer with mean-shift algorithm.
      #
      # @param bandwidth [Float] The bandwidth parameter of flat kernel.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
        super()
        @params = {
          bandwidth: bandwidth,
          max_iter: max_iter,
          tol: tol
        }
      end

      # Analyze clusters with given training data.
      #
      # Every training sample is used as a seed. Each iteration moves every seed to the mean of the
      # training samples whose Euclidean distance to it is at most +bandwidth+ (flat kernel).
      # Iteration stops once the largest per-seed L1 shift is at most +tol+, or after +max_iter+
      # iterations. The resulting modes are then merged into cluster centers: a mode within
      # +bandwidth+ of an already kept center is discarded.
      #
      # @overload fit(x) -> MeanShift
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      #   @return [MeanShift] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        x = Rumale::Validation.check_convert_sample_array(x)

        z = x.dup
        @params[:max_iter].times do
          updated = shift_seeds(x, z)
          converged = (z - updated).abs.sum(axis: 1).max <= @params[:tol]
          # Keep the freshly shifted seeds even on the converging iteration,
          # so the reported modes reflect the last computed means.
          z = updated
          break if converged
        end

        @cluster_centers = connect_components(z)

        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        fit(x).predict(x)
      end

      private

      # Perform one flat-kernel mean-shift step: move each seed (row of z) to the mean of the rows of x within bandwidth.
      def shift_seeds(x, z)
        # Rows correspond to samples in x, columns to seeds in z.
        distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
        kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
        # Normalize each seed's column by its neighbor count. Broadcasting the division avoids
        # building an n x n diagonal matrix and multiplying by it (O(n^2) memory, O(n^3) time).
        weight_mat = kernel_mat / kernel_mat.sum(axis: 0)
        weight_mat.transpose.dot(x)
      end

      # Label each row of x with the index of its nearest cluster center (squared error).
      def assign_cluster(x)
        n_clusters = @cluster_centers.shape[0]
        distance_mat = Rumale::PairwiseMetric.squared_error(x, @cluster_centers)
        # min_index returns flat indices into distance_mat; subtracting each row's starting
        # offset (row * n_clusters) turns them into per-row column indices.
        distance_mat.min_index(axis: 1) - Numo::Int32[*0.step(distance_mat.size - 1, n_clusters)]
      end

      # Merge converged modes: keep a mode as a new center unless it lies within bandwidth of a center already kept.
      def connect_components(z)
        centers = []

        z.shape[0].times do |idx|
          mode = z[idx, true]
          next if centers.any? { |center| Math.sqrt(((mode - center)**2).sum.abs) <= @params[:bandwidth] }

          centers << mode.dup
        end

        Numo::DFloat.asarray(centers)
      end
    end
  end
end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of classes that implement cluster analysis methods.
6
6
  module Clustering
7
7
  # @!visibility private
8
- VERSION = '0.25.0'
8
+ VERSION = '0.26.0'
9
9
  end
10
10
  end
@@ -7,6 +7,7 @@ require_relative 'clustering/gaussian_mixture'
7
7
  require_relative 'clustering/hdbscan'
8
8
  require_relative 'clustering/k_means'
9
9
  require_relative 'clustering/k_medoids'
10
+ require_relative 'clustering/mean_shift'
10
11
  require_relative 'clustering/mini_batch_k_means'
11
12
  require_relative 'clustering/power_iteration'
12
13
  require_relative 'clustering/single_linkage'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-clustering
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.25.0
4
+ version: 0.26.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-01-18 00:00:00.000000000 Z
11
+ date: 2023-02-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: 0.25.0
33
+ version: 0.26.0
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: 0.25.0
40
+ version: 0.26.0
41
41
  description: |
42
42
  Rumale::Clustering provides cluster analysis algorithms,
43
43
  such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
@@ -56,6 +56,7 @@ files:
56
56
  - lib/rumale/clustering/hdbscan.rb
57
57
  - lib/rumale/clustering/k_means.rb
58
58
  - lib/rumale/clustering/k_medoids.rb
59
+ - lib/rumale/clustering/mean_shift.rb
59
60
  - lib/rumale/clustering/mini_batch_k_means.rb
60
61
  - lib/rumale/clustering/power_iteration.rb
61
62
  - lib/rumale/clustering/single_linkage.rb