rumale 0.19.0 → 0.19.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/rumale.rb +1 -0
- data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
- data/lib/rumale/neural_network/adam.rb +1 -1
- data/lib/rumale/version.rb +1 -1
- metadata +6 -5
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: f49170105721cfebcae9f1a424e9a858650d78225541a8cb63b0ad4c70734988
|
|
4
|
+
data.tar.gz: ecc35086328eee1066252e75b8cd638256039e93beebc0bce5714493fe72570b
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 68f432bb34ff6c8e467a91d7c7e3aa07e816c2dd8807defc9e4e82e7a720c925062dbd27c8a7ec3294ecef2d71041baead2510edaf03a1eee210dc811eede22d
|
|
7
|
+
data.tar.gz: 5854eacc12de6c3cdcdbab0f9b4e73fc64d1be0533732348da6b4d6dcb0be9f115e2415501b05148fd021fa844ac0c25adc1bb858432a02ca6fe19d30a3538c7
|
data/CHANGELOG.md
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
# 0.19.1
|
|
2
|
+
- Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
|
|
3
|
+
- Fix some typos.
|
|
4
|
+
|
|
1
5
|
# 0.19.0
|
|
2
6
|
- Change mmh3 and mopti gem to non-runtime dependent library.
|
|
3
7
|
- The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
|
data/lib/rumale.rb
CHANGED
|
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
|
|
|
70
70
|
require 'rumale/ensemble/extra_trees_classifier'
|
|
71
71
|
require 'rumale/ensemble/extra_trees_regressor'
|
|
72
72
|
require 'rumale/clustering/k_means'
|
|
73
|
+
require 'rumale/clustering/mini_batch_k_means'
|
|
73
74
|
require 'rumale/clustering/k_medoids'
|
|
74
75
|
require 'rumale/clustering/gaussian_mixture'
|
|
75
76
|
require 'rumale/clustering/dbscan'
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'rumale/base/base_estimator'
|
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
|
5
|
+
require 'rumale/pairwise_metric'
|
|
6
|
+
|
|
7
|
+
module Rumale
|
|
8
|
+
module Clustering
|
|
9
|
+
# MiniBatchKMeans is a class that implements K-Means cluster analysis
|
|
10
|
+
# with mini-batch stochastic gradient descent (SGD).
|
|
11
|
+
#
|
|
12
|
+
# @example
|
|
13
|
+
# analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
|
|
14
|
+
# cluster_labels = analyzer.fit_predict(samples)
|
|
15
|
+
#
|
|
16
|
+
# *Reference*
|
|
17
|
+
# - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
|
|
18
|
+
class MiniBatchKMeans
|
|
19
|
+
include Base::BaseEstimator
|
|
20
|
+
include Base::ClusterAnalyzer
|
|
21
|
+
|
|
22
|
+
# Return the centroids.
|
|
23
|
+
# @return [Numo::DFloat] (shape: [n_clusters, n_features])
|
|
24
|
+
attr_reader :cluster_centers
|
|
25
|
+
|
|
26
|
+
# Return the random generator.
|
|
27
|
+
# @return [Random]
|
|
28
|
+
attr_reader :rng
|
|
29
|
+
|
|
30
|
+
# Create a new cluster analyzer with K-Means method with mini-batch SGD.
|
|
31
|
+
#
|
|
32
|
+
# @param n_clusters [Integer] The number of clusters.
|
|
33
|
+
# @param init [String] The initialization method for centroids ('random' or 'k-means++').
|
|
34
|
+
# @param max_iter [Integer] The maximum number of iterations.
|
|
35
|
+
# @param batch_size [Integer] The size of the mini batches.
|
|
36
|
+
# @param tol [Float] The tolerance of termination criterion.
|
|
37
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
|
38
|
+
def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
|
|
39
|
+
check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
|
|
40
|
+
check_params_string(init: init)
|
|
41
|
+
check_params_numeric_or_nil(random_seed: random_seed)
|
|
42
|
+
check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
|
|
43
|
+
@params = {}
|
|
44
|
+
@params[:n_clusters] = n_clusters
|
|
45
|
+
@params[:init] = init == 'random' ? 'random' : 'k-means++'
|
|
46
|
+
@params[:max_iter] = max_iter
|
|
47
|
+
@params[:batch_size] = batch_size
|
|
48
|
+
@params[:tol] = tol
|
|
49
|
+
@params[:random_seed] = random_seed
|
|
50
|
+
@params[:random_seed] ||= srand
|
|
51
|
+
@cluster_centers = nil
|
|
52
|
+
@rng = Random.new(@params[:random_seed])
|
|
53
|
+
end
|
|
54
|
+
|
|
55
|
+
# Analyze clusters with the given training data.
|
|
56
|
+
#
|
|
57
|
+
# @overload fit(x) -> MiniBatchKMeans
|
|
58
|
+
#
|
|
59
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
|
60
|
+
# @return [MiniBatchKMeans] The learned cluster analyzer itself.
|
|
61
|
+
def fit(x, _y = nil)
|
|
62
|
+
x = check_convert_sample_array(x)
|
|
63
|
+
# initialization.
|
|
64
|
+
n_samples = x.shape[0]
|
|
65
|
+
update_counter = Numo::Int32.zeros(@params[:n_clusters])
|
|
66
|
+
sub_rng = @rng.dup
|
|
67
|
+
init_cluster_centers(x, sub_rng)
|
|
68
|
+
# optimization with mini-batch sgd.
|
|
69
|
+
@params[:max_iter].times do |_t|
|
|
70
|
+
sample_ids = [*0...n_samples].shuffle(random: sub_rng)
|
|
71
|
+
old_centers = @cluster_centers.dup
|
|
72
|
+
until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
|
|
73
|
+
# sub sampling
|
|
74
|
+
sub_x = x[subset_ids, true]
|
|
75
|
+
# assign nearest centroids
|
|
76
|
+
cluster_labels = assign_cluster(sub_x)
|
|
77
|
+
# update centroids
|
|
78
|
+
@params[:n_clusters].times do |c|
|
|
79
|
+
assigned_bits = cluster_labels.eq(c)
|
|
80
|
+
next unless assigned_bits.count.positive?
|
|
81
|
+
|
|
82
|
+
update_counter[c] += 1
|
|
83
|
+
learning_rate = 1.fdiv(update_counter[c])
|
|
84
|
+
update = sub_x[assigned_bits.where, true].mean(axis: 0)
|
|
85
|
+
@cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
|
|
86
|
+
end
|
|
87
|
+
end
|
|
88
|
+
error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
|
|
89
|
+
break if error <= @params[:tol]
|
|
90
|
+
end
|
|
91
|
+
self
|
|
92
|
+
end
|
|
93
|
+
|
|
94
|
+
# Predict cluster labels for samples.
|
|
95
|
+
#
|
|
96
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
|
|
97
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
|
98
|
+
def predict(x)
|
|
99
|
+
x = check_convert_sample_array(x)
|
|
100
|
+
assign_cluster(x)
|
|
101
|
+
end
|
|
102
|
+
|
|
103
|
+
# Analyze clusters and assign samples to clusters.
|
|
104
|
+
#
|
|
105
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
|
106
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
|
107
|
+
def fit_predict(x)
|
|
108
|
+
x = check_convert_sample_array(x)
|
|
109
|
+
fit(x)
|
|
110
|
+
predict(x)
|
|
111
|
+
end
|
|
112
|
+
|
|
113
|
+
private
|
|
114
|
+
|
|
115
|
+
def assign_cluster(x)
|
|
116
|
+
distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
|
|
117
|
+
distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
|
|
118
|
+
end
|
|
119
|
+
|
|
120
|
+
def init_cluster_centers(x, sub_rng)
|
|
121
|
+
# random initialize
|
|
122
|
+
n_samples = x.shape[0]
|
|
123
|
+
rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
|
|
124
|
+
@cluster_centers = x[rand_id, true].dup
|
|
125
|
+
return unless @params[:init] == 'k-means++'
|
|
126
|
+
|
|
127
|
+
# k-means++ initialize
|
|
128
|
+
(1...@params[:n_clusters]).each do |n|
|
|
129
|
+
distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
|
|
130
|
+
min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
|
|
131
|
+
probs = min_distances**2 / (min_distances**2).sum
|
|
132
|
+
cum_probs = probs.cumsum
|
|
133
|
+
selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
|
|
134
|
+
@cluster_centers[n, true] = x[selected_id, true].dup
|
|
135
|
+
end
|
|
136
|
+
end
|
|
137
|
+
end
|
|
138
|
+
end
|
|
139
|
+
end
|
|
@@ -32,7 +32,7 @@ module Rumale
|
|
|
32
32
|
end
|
|
33
33
|
|
|
34
34
|
# @!visibility private
|
|
35
|
-
# Calculate the updated weight with
|
|
35
|
+
# Calculate the updated weight with Adam adaptive learning rate.
|
|
36
36
|
#
|
|
37
37
|
# @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
|
|
38
38
|
# @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
|
data/lib/rumale/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: rumale
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.19.0
|
|
4
|
+
version: 0.19.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- yoshoku
|
|
8
|
-
autorequire:
|
|
8
|
+
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2020-
|
|
11
|
+
date: 2020-06-06 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: numo-narray
|
|
@@ -72,6 +72,7 @@ files:
|
|
|
72
72
|
- lib/rumale/clustering/hdbscan.rb
|
|
73
73
|
- lib/rumale/clustering/k_means.rb
|
|
74
74
|
- lib/rumale/clustering/k_medoids.rb
|
|
75
|
+
- lib/rumale/clustering/mini_batch_k_means.rb
|
|
75
76
|
- lib/rumale/clustering/power_iteration.rb
|
|
76
77
|
- lib/rumale/clustering/single_linkage.rb
|
|
77
78
|
- lib/rumale/clustering/snn.rb
|
|
@@ -196,7 +197,7 @@ metadata:
|
|
|
196
197
|
source_code_uri: https://github.com/yoshoku/rumale
|
|
197
198
|
documentation_uri: https://yoshoku.github.io/rumale/doc/
|
|
198
199
|
bug_tracker_uri: https://github.com/yoshoku/rumale/issues
|
|
199
|
-
post_install_message:
|
|
200
|
+
post_install_message:
|
|
200
201
|
rdoc_options: []
|
|
201
202
|
require_paths:
|
|
202
203
|
- lib
|
|
@@ -212,7 +213,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
|
212
213
|
version: '0'
|
|
213
214
|
requirements: []
|
|
214
215
|
rubygems_version: 3.1.2
|
|
215
|
-
signing_key:
|
|
216
|
+
signing_key:
|
|
216
217
|
specification_version: 4
|
|
217
218
|
summary: Rumale is a machine learning library in Ruby. Rumale provides machine learning
|
|
218
219
|
algorithms with interfaces similar to Scikit-Learn in Python.
|