rumale-clustering 0.24.0 → 0.26.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/LICENSE.txt +1 -1
 - data/lib/rumale/clustering/mean_shift.rb +116 -0
 - data/lib/rumale/clustering/version.rb +1 -1
 - data/lib/rumale/clustering.rb +1 -0
 - metadata +5 -4
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: d4bd3dd2f04f44e145f7ebe90154dfd9e2970485c58cd68af82abd0f744e78c0
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 9af9eafaa42f596f75c09a13b6da658bf7ceb69af657a22bc50f6c58b7a2eb05
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: 2bcbe3e94d4ae65507fb6253b68264dcc42ed455a7150d7096ea71d4d838dae2de2ba2b7991e06b258b16e28f701f5c350b5e26a7c06f43a0b08d6e0145c5cb1
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: c08a182fd31b16aaad51186c4dfe3457e5ea7d5f9b956a966ead434a036d9817189cd3e5ee7be824a04c06baf60771423f9ef23c0488aceb1705dfde7aaeb8a4
         
     | 
    
        data/lib/rumale/clustering/mean_shift.rb
    CHANGED
    
    
| 
         @@ -0,0 +1,116 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'rumale/base/estimator'
         
     | 
| 
      
 4 
     | 
    
         
            +
            require 'rumale/base/cluster_analyzer'
         
     | 
| 
      
 5 
     | 
    
         
            +
            require 'rumale/pairwise_metric'
         
     | 
| 
      
 6 
     | 
    
         
            +
            require 'rumale/validation'
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            module Rumale
              module Clustering
                # MeanShift is a class that implements mean-shift clustering with flat kernel.
                #
                # @example
                #   require 'rumale/clustering/mean_shift'
                #
                #   analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
                #   cluster_labels = analyzer.fit_predict(samples)
                #
                # *Reference*
                # - Carreira-Perpinan, M A., "A review of mean-shift algorithms for clustering," arXiv:1503.00687v1.
                # - Sheikh, Y A., Khan, E A., and Kanade, T., "Mode-seeking by Medoidshifts," Proc. ICCV'07, pp. 1--8, 2007.
                # - Vedaldi, A., and Soatto, S., "Quick Shift and Kernel Methods for Mode Seeking," Proc. ECCV'08, pp. 705--718, 2008.
                class MeanShift < Rumale::Base::Estimator
                  include Rumale::Base::ClusterAnalyzer

                  # Return the centroids.
                  # @return [Numo::DFloat] (shape: [n_clusters, n_features])
                  attr_reader :cluster_centers

                  # Create a new cluster analyzer with mean-shift algorithm.
                  #
                  # @param bandwidth [Float] The bandwidth parameter of flat kernel.
                  #   It is also used as the radius within which converged points are merged into one cluster center.
                  # @param max_iter [Integer] The maximum number of iterations.
                  # @param tol [Float] The tolerance of termination criterion.
                  def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
                    super()
                    @params = {
                      bandwidth: bandwidth,
                      max_iter: max_iter,
                      tol: tol
                    }
                  end

                  # Analyze clusters with given training data.
                  #
                  # @overload fit(x) -> MeanShift
                  #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
                  #   @return [MeanShift] The learned cluster analyzer itself.
                  def fit(x, _y = nil)
                    x = Rumale::Validation.check_convert_sample_array(x)

                    # Each row of z is a mode estimate, initialized at its own sample and
                    # repeatedly moved to the mean of the training samples within the bandwidth.
                    z = x.dup
                    @params[:max_iter].times do
                      distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
                      # kernel_mat[i, j] is 1.0 when sample i lies within the bandwidth of estimate j, else 0.0.
                      kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
                      sum_kernel = kernel_mat.sum(axis: 0)
                      # Normalize column j by sum_kernel[j] via broadcasting. This equals
                      # kernel_mat.dot((1 / sum_kernel).diag) without materializing an
                      # n_samples x n_samples diagonal matrix or paying for an O(n^3) product.
                      weight_mat = kernel_mat / sum_kernel
                      updated = weight_mat.transpose.dot(x)
                      # Stop once no estimate moved more than tol (L1 distance).
                      break if (z - updated).abs.sum(axis: 1).max <= @params[:tol]

                      z = updated
                    end

                    @cluster_centers = connect_components(z)

                    self
                  end

                  # Predict cluster labels for samples.
                  #
                  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
                  # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
                  def predict(x)
                    x = Rumale::Validation.check_convert_sample_array(x)

                    assign_cluster(x)
                  end

                  # Analyze clusters and assign samples to clusters.
                  #
                  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
                  # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
                  def fit_predict(x)
                    x = Rumale::Validation.check_convert_sample_array(x)

                    fit(x).predict(x)
                  end

                  private

                  # Label each sample with the index of its nearest cluster center.
                  def assign_cluster(x)
                    n_clusters = @cluster_centers.shape[0]
                    distance_mat = Rumale::PairwiseMetric.squared_error(x, @cluster_centers)
                    # min_index returns flat indices into distance_mat; subtracting each row's
                    # starting offset (row * n_clusters) turns them into column (cluster) indices.
                    distance_mat.min_index(axis: 1) - Numo::Int32[*0.step(distance_mat.size - 1, n_clusters)]
                  end

                  # Greedily merge converged mode estimates: a point becomes a new center
                  # unless it lies within the bandwidth of a center already collected.
                  def connect_components(z)
                    centers = []

                    z.shape[0].times do |idx|
                      point = z[idx, true]
                      near_existing = centers.any? do |cluster_vec|
                        Math.sqrt(((point - cluster_vec)**2).sum.abs) <= @params[:bandwidth]
                      end
                      centers << point.dup unless near_existing
                    end

                    Numo::DFloat.asarray(centers)
                  end
                end
              end
            end
         
     | 
    
        data/lib/rumale/clustering.rb
    CHANGED
    
    | 
         @@ -7,6 +7,7 @@ require_relative 'clustering/gaussian_mixture' 
     | 
|
| 
       7 
7 
     | 
    
         
             
            require_relative 'clustering/hdbscan'
         
     | 
| 
       8 
8 
     | 
    
         
             
            require_relative 'clustering/k_means'
         
     | 
| 
       9 
9 
     | 
    
         
             
            require_relative 'clustering/k_medoids'
         
     | 
| 
      
 10 
     | 
    
         
            +
            require_relative 'clustering/mean_shift'
         
     | 
| 
       10 
11 
     | 
    
         
             
            require_relative 'clustering/mini_batch_k_means'
         
     | 
| 
       11 
12 
     | 
    
         
             
            require_relative 'clustering/power_iteration'
         
     | 
| 
       12 
13 
     | 
    
         
             
            require_relative 'clustering/single_linkage'
         
     | 
    
        metadata
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: rumale-clustering
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.24.0
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.26.0
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - yoshoku
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire:
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: exe
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date:  
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2023-02-19 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies:
         
     | 
| 
       13 
13 
     | 
    
         
             
            - !ruby/object:Gem::Dependency
         
     | 
| 
       14 
14 
     | 
    
         
             
              name: numo-narray
         
     | 
| 
         @@ -30,14 +30,14 @@ dependencies: 
     | 
|
| 
       30 
30 
     | 
    
         
             
                requirements:
         
     | 
| 
       31 
31 
     | 
    
         
             
                - - "~>"
         
     | 
| 
       32 
32 
     | 
    
         
             
                  - !ruby/object:Gem::Version
         
     | 
| 
       33 
     | 
    
         
            -
                    version: 0. 
     | 
| 
      
 33 
     | 
    
         
            +
                    version: 0.26.0
         
     | 
| 
       34 
34 
     | 
    
         
             
              type: :runtime
         
     | 
| 
       35 
35 
     | 
    
         
             
              prerelease: false
         
     | 
| 
       36 
36 
     | 
    
         
             
              version_requirements: !ruby/object:Gem::Requirement
         
     | 
| 
       37 
37 
     | 
    
         
             
                requirements:
         
     | 
| 
       38 
38 
     | 
    
         
             
                - - "~>"
         
     | 
| 
       39 
39 
     | 
    
         
             
                  - !ruby/object:Gem::Version
         
     | 
| 
       40 
     | 
    
         
            -
                    version: 0. 
     | 
| 
      
 40 
     | 
    
         
            +
                    version: 0.26.0
         
     | 
| 
       41 
41 
     | 
    
         
             
            description: |
         
     | 
| 
       42 
42 
     | 
    
         
             
              Rumale::Clustering provides cluster analysis algorithms,
         
     | 
| 
       43 
43 
     | 
    
         
             
              such as K-Means, Gaussian Mixture Model, DBSCAN, and Spectral Clustering,
         
     | 
| 
         @@ -56,6 +56,7 @@ files: 
     | 
|
| 
       56 
56 
     | 
    
         
             
            - lib/rumale/clustering/hdbscan.rb
         
     | 
| 
       57 
57 
     | 
    
         
             
            - lib/rumale/clustering/k_means.rb
         
     | 
| 
       58 
58 
     | 
    
         
             
            - lib/rumale/clustering/k_medoids.rb
         
     | 
| 
      
 59 
     | 
    
         
            +
            - lib/rumale/clustering/mean_shift.rb
         
     | 
| 
       59 
60 
     | 
    
         
             
            - lib/rumale/clustering/mini_batch_k_means.rb
         
     | 
| 
       60 
61 
     | 
    
         
             
            - lib/rumale/clustering/power_iteration.rb
         
     | 
| 
       61 
62 
     | 
    
         
             
            - lib/rumale/clustering/single_linkage.rb
         
     |