rumale 0.19.0 → 0.19.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 1cd1cdc16e6c72743d064db7d254c74eb98ca33e97cfdf9e8a76cc1fbe5dd29b
4
- data.tar.gz: 2077cae2629f2c403cc0afc415dc4b4151a2eac2ab7a3230402bf761bb653829
3
+ metadata.gz: f49170105721cfebcae9f1a424e9a858650d78225541a8cb63b0ad4c70734988
4
+ data.tar.gz: ecc35086328eee1066252e75b8cd638256039e93beebc0bce5714493fe72570b
5
5
  SHA512:
6
- metadata.gz: 2bbcdce6d0a31c95500a81a7d4a55407786068fac65ce1e5ede1bc3f56d97b2ec93fd9ca1dc52fc1a24782dba469099b25d0398a5993716da011851f18f8179c
7
- data.tar.gz: cc9fc19ea73dfa76e8ede18df7cb57f931cccdea2c546f414746ed681afbd272e3866913c75534bbdce6b275d057baa2c793b4046c8e46f697e30d5b87dba066
6
+ metadata.gz: 68f432bb34ff6c8e467a91d7c7e3aa07e816c2dd8807defc9e4e82e7a720c925062dbd27c8a7ec3294ecef2d71041baead2510edaf03a1eee210dc811eede22d
7
+ data.tar.gz: 5854eacc12de6c3cdcdbab0f9b4e73fc64d1be0533732348da6b4d6dcb0be9f115e2415501b05148fd021fa844ac0c25adc1bb858432a02ca6fe19d30a3538c7
@@ -1,3 +1,7 @@
1
+ # 0.19.1
2
+ - Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
3
+ - Fix some typos.
4
+
1
5
  # 0.19.0
2
6
  - Change mmh3 and mopti gem to non-runtime dependent library.
3
7
  - The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
70
70
  require 'rumale/ensemble/extra_trees_classifier'
71
71
  require 'rumale/ensemble/extra_trees_regressor'
72
72
  require 'rumale/clustering/k_means'
73
+ require 'rumale/clustering/mini_batch_k_means'
73
74
  require 'rumale/clustering/k_medoids'
74
75
  require 'rumale/clustering/gaussian_mixture'
75
76
  require 'rumale/clustering/dbscan'
@@ -0,0 +1,139 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
+ module Rumale
8
+ module Clustering
9
+ # MiniBatchKMeans is a class that implements K-Means cluster analysis
10
+ # with mini-batch stochastic gradient descent (SGD).
11
+ #
12
+ # @example
13
+ # analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
14
+ # cluster_labels = analyzer.fit_predict(samples)
15
+ #
16
+ # *Reference*
17
+ # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
18
+ class MiniBatchKMeans
19
+ include Base::BaseEstimator
20
+ include Base::ClusterAnalyzer
21
+
22
+ # Return the centroids.
23
+ # @return [Numo::DFloat] (shape: [n_clusters, n_features])
24
+ attr_reader :cluster_centers
25
+
26
+ # Return the random generator.
27
+ # @return [Random]
28
+ attr_reader :rng
29
+
30
+ # Create a new cluster analyzer with K-Means method with mini-batch SGD.
31
+ #
32
+ # @param n_clusters [Integer] The number of clusters.
33
+ # @param init [String] The initialization method for centroids ('random' or 'k-means++').
34
+ # @param max_iter [Integer] The maximum number of iterations.
35
+ # @param batch_size [Integer] The size of the mini batches.
36
+ # @param tol [Float] The tolerance of termination criterion.
37
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
38
+ def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
39
+ check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
40
+ check_params_string(init: init)
41
+ check_params_numeric_or_nil(random_seed: random_seed)
42
+ check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
43
+ @params = {}
44
+ @params[:n_clusters] = n_clusters
45
+ @params[:init] = init == 'random' ? 'random' : 'k-means++'
46
+ @params[:max_iter] = max_iter
47
+ @params[:batch_size] = batch_size
48
+ @params[:tol] = tol
49
+ @params[:random_seed] = random_seed
50
+ @params[:random_seed] ||= srand
51
+ @cluster_centers = nil
52
+ @rng = Random.new(@params[:random_seed])
53
+ end
54
+
55
+ # Analyze clusters with the given training data.
56
+ #
57
+ # @overload fit(x) -> MiniBatchKMeans
58
+ #
59
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
60
+ # @return [MiniBatchKMeans] The learned cluster analyzer itself.
61
+ def fit(x, _y = nil)
62
+ x = check_convert_sample_array(x)
63
+ # initialization.
64
+ n_samples = x.shape[0]
65
+ update_counter = Numo::Int32.zeros(@params[:n_clusters])
66
+ sub_rng = @rng.dup
67
+ init_cluster_centers(x, sub_rng)
68
+ # optimization with mini-batch sgd.
69
+ @params[:max_iter].times do |_t|
70
+ sample_ids = [*0...n_samples].shuffle(random: sub_rng)
71
+ old_centers = @cluster_centers.dup
72
+ until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
73
+ # sub sampling
74
+ sub_x = x[subset_ids, true]
75
+ # assign nearest centroids
76
+ cluster_labels = assign_cluster(sub_x)
77
+ # update centroids
78
+ @params[:n_clusters].times do |c|
79
+ assigned_bits = cluster_labels.eq(c)
80
+ next unless assigned_bits.count.positive?
81
+
82
+ update_counter[c] += 1
83
+ learning_rate = 1.fdiv(update_counter[c])
84
+ update = sub_x[assigned_bits.where, true].mean(axis: 0)
85
+ @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
86
+ end
87
+ end
88
+ error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
89
+ break if error <= @params[:tol]
90
+ end
91
+ self
92
+ end
93
+
94
+ # Predict cluster labels for samples.
95
+ #
96
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
97
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
98
+ def predict(x)
99
+ x = check_convert_sample_array(x)
100
+ assign_cluster(x)
101
+ end
102
+
103
+ # Analyze clusters and assign samples to clusters.
104
+ #
105
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
106
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
107
+ def fit_predict(x)
108
+ x = check_convert_sample_array(x)
109
+ fit(x)
110
+ predict(x)
111
+ end
112
+
113
+ private
114
+
115
+ def assign_cluster(x)
116
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
117
+ distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
118
+ end
119
+
120
+ def init_cluster_centers(x, sub_rng)
121
+ # random initialize
122
+ n_samples = x.shape[0]
123
+ rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
124
+ @cluster_centers = x[rand_id, true].dup
125
+ return unless @params[:init] == 'k-means++'
126
+
127
+ # k-means++ initialize
128
+ (1...@params[:n_clusters]).each do |n|
129
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
130
+ min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
131
+ probs = min_distances**2 / (min_distances**2).sum
132
+ cum_probs = probs.cumsum
133
+ selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
134
+ @cluster_centers[n, true] = x[selected_id, true].dup
135
+ end
136
+ end
137
+ end
138
+ end
139
+ end
@@ -32,7 +32,7 @@ module Rumale
32
32
  end
33
33
 
34
34
  # @!visibility private
35
- # Calculate the updated weight with Nadam adaptive learning rate.
35
+ # Calculate the updated weight with Adam adaptive learning rate.
36
36
  #
37
37
  # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
38
38
  # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.19.0'
6
+ VERSION = '0.19.1'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.19.0
4
+ version: 0.19.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
8
+ autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-05-23 00:00:00.000000000 Z
11
+ date: 2020-06-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -72,6 +72,7 @@ files:
72
72
  - lib/rumale/clustering/hdbscan.rb
73
73
  - lib/rumale/clustering/k_means.rb
74
74
  - lib/rumale/clustering/k_medoids.rb
75
+ - lib/rumale/clustering/mini_batch_k_means.rb
75
76
  - lib/rumale/clustering/power_iteration.rb
76
77
  - lib/rumale/clustering/single_linkage.rb
77
78
  - lib/rumale/clustering/snn.rb
@@ -196,7 +197,7 @@ metadata:
196
197
  source_code_uri: https://github.com/yoshoku/rumale
197
198
  documentation_uri: https://yoshoku.github.io/rumale/doc/
198
199
  bug_tracker_uri: https://github.com/yoshoku/rumale/issues
199
- post_install_message:
200
+ post_install_message:
200
201
  rdoc_options: []
201
202
  require_paths:
202
203
  - lib
@@ -212,7 +213,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
212
213
  version: '0'
213
214
  requirements: []
214
215
  rubygems_version: 3.1.2
215
- signing_key:
216
+ signing_key:
216
217
  specification_version: 4
217
218
  summary: Rumale is a machine learning library in Ruby. Rumale provides machine learning
218
219
  algorithms with interfaces similar to Scikit-Learn in Python.