rumale 0.19.0 → 0.19.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 1cd1cdc16e6c72743d064db7d254c74eb98ca33e97cfdf9e8a76cc1fbe5dd29b
4
- data.tar.gz: 2077cae2629f2c403cc0afc415dc4b4151a2eac2ab7a3230402bf761bb653829
3
+ metadata.gz: f49170105721cfebcae9f1a424e9a858650d78225541a8cb63b0ad4c70734988
4
+ data.tar.gz: ecc35086328eee1066252e75b8cd638256039e93beebc0bce5714493fe72570b
5
5
  SHA512:
6
- metadata.gz: 2bbcdce6d0a31c95500a81a7d4a55407786068fac65ce1e5ede1bc3f56d97b2ec93fd9ca1dc52fc1a24782dba469099b25d0398a5993716da011851f18f8179c
7
- data.tar.gz: cc9fc19ea73dfa76e8ede18df7cb57f931cccdea2c546f414746ed681afbd272e3866913c75534bbdce6b275d057baa2c793b4046c8e46f697e30d5b87dba066
6
+ metadata.gz: 68f432bb34ff6c8e467a91d7c7e3aa07e816c2dd8807defc9e4e82e7a720c925062dbd27c8a7ec3294ecef2d71041baead2510edaf03a1eee210dc811eede22d
7
+ data.tar.gz: 5854eacc12de6c3cdcdbab0f9b4e73fc64d1be0533732348da6b4d6dcb0be9f115e2415501b05148fd021fa844ac0c25adc1bb858432a02ca6fe19d30a3538c7
@@ -1,3 +1,7 @@
1
+ # 0.19.1
2
+ - Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
3
+ - Fix some typos.
4
+
1
5
  # 0.19.0
2
6
  - Change mmh3 and mopti gem to non-runtime dependent library.
3
7
  - The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
70
70
  require 'rumale/ensemble/extra_trees_classifier'
71
71
  require 'rumale/ensemble/extra_trees_regressor'
72
72
  require 'rumale/clustering/k_means'
73
+ require 'rumale/clustering/mini_batch_k_means'
73
74
  require 'rumale/clustering/k_medoids'
74
75
  require 'rumale/clustering/gaussian_mixture'
75
76
  require 'rumale/clustering/dbscan'
@@ -0,0 +1,139 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
+ module Rumale
8
+ module Clustering
9
+ # MniBatchKMeans is a class that implements K-Means cluster analysis
10
+ # with mini-batch stochastic gradient descent (SGD).
11
+ #
12
+ # @example
13
+ # analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
14
+ # cluster_labels = analyzer.fit_predict(samples)
15
+ #
16
+ # *Reference*
17
+ # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
18
+ class MiniBatchKMeans
19
+ include Base::BaseEstimator
20
+ include Base::ClusterAnalyzer
21
+
22
+ # Return the centroids.
23
+ # @return [Numo::DFloat] (shape: [n_clusters, n_features])
24
+ attr_reader :cluster_centers
25
+
26
+ # Return the random generator.
27
+ # @return [Random]
28
+ attr_reader :rng
29
+
30
+ # Create a new cluster analyzer with K-Means method with mini-batch SGD.
31
+ #
32
+ # @param n_clusters [Integer] The number of clusters.
33
+ # @param init [String] The initialization method for centroids ('random' or 'k-means++').
34
+ # @param max_iter [Integer] The maximum number of iterations.
35
+ # @param batch_size [Integer] The size of the mini batches.
36
+ # @param tol [Float] The tolerance of termination criterion.
37
+ # @param random_seed [Integer] The seed value using to initialize the random generator.
38
+ def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
39
+ check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
40
+ check_params_string(init: init)
41
+ check_params_numeric_or_nil(random_seed: random_seed)
42
+ check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
43
+ @params = {}
44
+ @params[:n_clusters] = n_clusters
45
+ @params[:init] = init == 'random' ? 'random' : 'k-means++'
46
+ @params[:max_iter] = max_iter
47
+ @params[:batch_size] = batch_size
48
+ @params[:tol] = tol
49
+ @params[:random_seed] = random_seed
50
+ @params[:random_seed] ||= srand
51
+ @cluster_centers = nil
52
+ @rng = Random.new(@params[:random_seed])
53
+ end
54
+
55
+ # Analysis clusters with given training data.
56
+ #
57
+ # @overload fit(x) -> MiniBatchKMeans
58
+ #
59
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
60
+ # @return [KMeans] The learned cluster analyzer itself.
61
+ def fit(x, _y = nil)
62
+ x = check_convert_sample_array(x)
63
+ # initialization.
64
+ n_samples = x.shape[0]
65
+ update_counter = Numo::Int32.zeros(@params[:n_clusters])
66
+ sub_rng = @rng.dup
67
+ init_cluster_centers(x, sub_rng)
68
+ # optimization with mini-batch sgd.
69
+ @params[:max_iter].times do |_t|
70
+ sample_ids = [*0...n_samples].shuffle(random: sub_rng)
71
+ old_centers = @cluster_centers.dup
72
+ until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
73
+ # sub sampling
74
+ sub_x = x[subset_ids, true]
75
+ # assign nearest centroids
76
+ cluster_labels = assign_cluster(sub_x)
77
+ # update centroids
78
+ @params[:n_clusters].times do |c|
79
+ assigned_bits = cluster_labels.eq(c)
80
+ next unless assigned_bits.count.positive?
81
+
82
+ update_counter[c] += 1
83
+ learning_rate = 1.fdiv(update_counter[c])
84
+ update = sub_x[assigned_bits.where, true].mean(axis: 0)
85
+ @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
86
+ end
87
+ end
88
+ error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
89
+ break if error <= @params[:tol]
90
+ end
91
+ self
92
+ end
93
+
94
+ # Predict cluster labels for samples.
95
+ #
96
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
97
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
98
+ def predict(x)
99
+ x = check_convert_sample_array(x)
100
+ assign_cluster(x)
101
+ end
102
+
103
+ # Analysis clusters and assign samples to clusters.
104
+ #
105
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
106
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
107
+ def fit_predict(x)
108
+ x = check_convert_sample_array(x)
109
+ fit(x)
110
+ predict(x)
111
+ end
112
+
113
+ private
114
+
115
+ def assign_cluster(x)
116
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
117
+ distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
118
+ end
119
+
120
+ def init_cluster_centers(x, sub_rng)
121
+ # random initialize
122
+ n_samples = x.shape[0]
123
+ rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
124
+ @cluster_centers = x[rand_id, true].dup
125
+ return unless @params[:init] == 'k-means++'
126
+
127
+ # k-means++ initialize
128
+ (1...@params[:n_clusters]).each do |n|
129
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
130
+ min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
131
+ probs = min_distances**2 / (min_distances**2).sum
132
+ cum_probs = probs.cumsum
133
+ selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
134
+ @cluster_centers[n, true] = x[selected_id, true].dup
135
+ end
136
+ end
137
+ end
138
+ end
139
+ end
@@ -32,7 +32,7 @@ module Rumale
32
32
  end
33
33
 
34
34
  # @!visibility private
35
- # Calculate the updated weight with Nadam adaptive learning rate.
35
+ # Calculate the updated weight with Adam adaptive learning rate.
36
36
  #
37
37
  # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
38
38
  # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.19.0'
6
+ VERSION = '0.19.1'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.19.0
4
+ version: 0.19.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
8
+ autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-05-23 00:00:00.000000000 Z
11
+ date: 2020-06-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -72,6 +72,7 @@ files:
72
72
  - lib/rumale/clustering/hdbscan.rb
73
73
  - lib/rumale/clustering/k_means.rb
74
74
  - lib/rumale/clustering/k_medoids.rb
75
+ - lib/rumale/clustering/mini_batch_k_means.rb
75
76
  - lib/rumale/clustering/power_iteration.rb
76
77
  - lib/rumale/clustering/single_linkage.rb
77
78
  - lib/rumale/clustering/snn.rb
@@ -196,7 +197,7 @@ metadata:
196
197
  source_code_uri: https://github.com/yoshoku/rumale
197
198
  documentation_uri: https://yoshoku.github.io/rumale/doc/
198
199
  bug_tracker_uri: https://github.com/yoshoku/rumale/issues
199
- post_install_message:
200
+ post_install_message:
200
201
  rdoc_options: []
201
202
  require_paths:
202
203
  - lib
@@ -212,7 +213,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
212
213
  version: '0'
213
214
  requirements: []
214
215
  rubygems_version: 3.1.2
215
- signing_key:
216
+ signing_key:
216
217
  specification_version: 4
217
218
  summary: Rumale is a machine learning library in Ruby. Rumale provides machine learning
218
219
  algorithms with interfaces similar to Scikit-Learn in Python.