rumale 0.18.5 → 0.19.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +15 -3
- data/.travis.yml +3 -3
- data/CHANGELOG.md +44 -0
- data/Gemfile +9 -0
- data/README.md +6 -44
- data/lib/rumale.rb +3 -0
- data/lib/rumale/base/base_estimator.rb +2 -0
- data/lib/rumale/clustering/dbscan.rb +5 -1
- data/lib/rumale/clustering/gaussian_mixture.rb +2 -0
- data/lib/rumale/clustering/hdbscan.rb +5 -3
- data/lib/rumale/clustering/k_means.rb +2 -1
- data/lib/rumale/clustering/k_medoids.rb +5 -1
- data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
- data/lib/rumale/clustering/power_iteration.rb +3 -1
- data/lib/rumale/clustering/single_linkage.rb +3 -1
- data/lib/rumale/clustering/snn.rb +2 -2
- data/lib/rumale/clustering/spectral_clustering.rb +2 -2
- data/lib/rumale/dataset.rb +2 -0
- data/lib/rumale/decomposition/factor_analysis.rb +3 -1
- data/lib/rumale/decomposition/fast_ica.rb +2 -2
- data/lib/rumale/decomposition/nmf.rb +1 -1
- data/lib/rumale/decomposition/pca.rb +25 -6
- data/lib/rumale/ensemble/ada_boost_classifier.rb +4 -1
- data/lib/rumale/ensemble/ada_boost_regressor.rb +4 -2
- data/lib/rumale/ensemble/extra_trees_classifier.rb +1 -1
- data/lib/rumale/ensemble/extra_trees_regressor.rb +1 -1
- data/lib/rumale/ensemble/gradient_boosting_classifier.rb +4 -4
- data/lib/rumale/ensemble/gradient_boosting_regressor.rb +4 -4
- data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +1 -1
- data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +1 -1
- data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +1 -1
- data/lib/rumale/evaluation_measure/function.rb +2 -1
- data/lib/rumale/evaluation_measure/mutual_information.rb +1 -1
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +4 -2
- data/lib/rumale/evaluation_measure/precision_recall.rb +5 -0
- data/lib/rumale/evaluation_measure/purity.rb +1 -1
- data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
- data/lib/rumale/evaluation_measure/silhouette_score.rb +3 -1
- data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -0
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
- data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
- data/lib/rumale/kernel_approximation/rbf.rb +1 -1
- data/lib/rumale/kernel_machine/kernel_fda.rb +1 -1
- data/lib/rumale/kernel_machine/kernel_pca.rb +1 -1
- data/lib/rumale/kernel_machine/kernel_ridge.rb +2 -0
- data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
- data/lib/rumale/linear_model/base_linear_model.rb +2 -0
- data/lib/rumale/linear_model/elastic_net.rb +3 -3
- data/lib/rumale/linear_model/lasso.rb +3 -3
- data/lib/rumale/linear_model/linear_regression.rb +2 -1
- data/lib/rumale/linear_model/logistic_regression.rb +3 -3
- data/lib/rumale/linear_model/ridge.rb +2 -1
- data/lib/rumale/linear_model/svc.rb +3 -3
- data/lib/rumale/linear_model/svr.rb +3 -3
- data/lib/rumale/manifold/mds.rb +3 -1
- data/lib/rumale/manifold/tsne.rb +6 -2
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +14 -1
- data/lib/rumale/model_selection/grid_search_cv.rb +1 -0
- data/lib/rumale/naive_bayes/bernoulli_nb.rb +1 -1
- data/lib/rumale/naive_bayes/multinomial_nb.rb +1 -1
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +1 -0
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +2 -0
- data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
- data/lib/rumale/neural_network/adam.rb +2 -2
- data/lib/rumale/neural_network/base_mlp.rb +1 -0
- data/lib/rumale/optimizer/ada_grad.rb +4 -1
- data/lib/rumale/optimizer/adam.rb +4 -1
- data/lib/rumale/optimizer/nadam.rb +6 -1
- data/lib/rumale/optimizer/rmsprop.rb +5 -2
- data/lib/rumale/optimizer/sgd.rb +3 -0
- data/lib/rumale/optimizer/yellow_fin.rb +4 -1
- data/lib/rumale/pipeline/pipeline.rb +3 -0
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +5 -0
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +7 -2
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +7 -2
- data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
- data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
- data/lib/rumale/preprocessing/one_hot_encoder.rb +3 -0
- data/lib/rumale/preprocessing/ordinal_encoder.rb +2 -0
- data/lib/rumale/preprocessing/polynomial_features.rb +1 -0
- data/lib/rumale/probabilistic_output.rb +4 -2
- data/lib/rumale/tree/base_decision_tree.rb +2 -0
- data/lib/rumale/tree/decision_tree_classifier.rb +1 -0
- data/lib/rumale/tree/extra_tree_classifier.rb +1 -1
- data/lib/rumale/tree/extra_tree_regressor.rb +1 -1
- data/lib/rumale/tree/gradient_tree_regressor.rb +5 -5
- data/lib/rumale/utils.rb +1 -0
- data/lib/rumale/validation.rb +7 -0
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +1 -13
- metadata +10 -133
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 1bff2e1e6182aa954be00ed107ed1bd81220298f89514b4b31304f8890ff27c4
|
|
4
|
+
data.tar.gz: '09b185f468baf9dbec6280fa6c06984c95919308f1d2247277bf30348ed392bc'
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 6d8f1fcaffcd6714c6156fc615d87e6b6950e82ab40fc7434cfc5a014d6c08eb0170ee7c45d8fed978c2a52f839b1ce647fd6e088cbab2ea45e517b34c88407a
|
|
7
|
+
data.tar.gz: b255ae4c24cdc91ebad59f79ee5a58c5d2a5ffa79bda0ac221e3a33bd824d2fd94e5cd83f3a06e54a2dc537a074276cea5a71651deeee2a304d23e963ff92c9d
|
data/.rubocop.yml
CHANGED
|
@@ -3,11 +3,11 @@ require:
|
|
|
3
3
|
- rubocop-rspec
|
|
4
4
|
|
|
5
5
|
AllCops:
|
|
6
|
-
TargetRubyVersion: 2.
|
|
6
|
+
TargetRubyVersion: 2.5
|
|
7
7
|
DisplayCopNames: true
|
|
8
8
|
DisplayStyleGuide: true
|
|
9
9
|
Exclude:
|
|
10
|
-
- '
|
|
10
|
+
- 'ext/rumale/extconf.rb'
|
|
11
11
|
- 'rumale.gemspec'
|
|
12
12
|
- 'Rakefile'
|
|
13
13
|
- 'Gemfile'
|
|
@@ -70,14 +70,26 @@ Naming/MethodParameterName:
|
|
|
70
70
|
Naming/ConstantName:
|
|
71
71
|
Enabled: false
|
|
72
72
|
|
|
73
|
+
Style/ExponentialNotation:
|
|
74
|
+
Enabled: true
|
|
75
|
+
|
|
73
76
|
Style/FormatStringToken:
|
|
74
77
|
Enabled: false
|
|
75
78
|
|
|
76
79
|
Style/NumericLiterals:
|
|
77
80
|
Enabled: false
|
|
78
81
|
|
|
82
|
+
Style/SlicingWithRange:
|
|
83
|
+
Enabled: true
|
|
84
|
+
|
|
79
85
|
Layout/EmptyLineAfterGuardClause:
|
|
80
|
-
Enabled:
|
|
86
|
+
Enabled: true
|
|
87
|
+
|
|
88
|
+
Layout/EmptyLinesAroundAttributeAccessor:
|
|
89
|
+
Enabled: true
|
|
90
|
+
|
|
91
|
+
Layout/SpaceAroundMethodCallOperator:
|
|
92
|
+
Enabled: true
|
|
81
93
|
|
|
82
94
|
RSpec/MultipleExpectations:
|
|
83
95
|
Enabled: false
|
data/.travis.yml
CHANGED
data/CHANGELOG.md
CHANGED
|
@@ -1,3 +1,47 @@
|
|
|
1
|
+
# 0.19.2
|
|
2
|
+
- Fix L2Normalizer to avoid zero divide.
|
|
3
|
+
- Add preprocessing class for [L1Normalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/L1Normalizer.html).
|
|
4
|
+
- Add transformer class for [TfidfTransformer](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/TfidfTransformer.html).
|
|
5
|
+
|
|
6
|
+
# 0.19.1
|
|
7
|
+
- Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
|
|
8
|
+
- Fix some typos.
|
|
9
|
+
|
|
10
|
+
# 0.19.0
|
|
11
|
+
- Change mmh3 and mopti gem to non-runtime dependent library.
|
|
12
|
+
- The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
|
|
13
|
+
You only need to require mmh3 gem when using FeatureHasher.
|
|
14
|
+
```ruby
|
|
15
|
+
require 'mmh3'
|
|
16
|
+
require 'rumale'
|
|
17
|
+
|
|
18
|
+
encoder = Rumale::FeatureExtraction::FeatureHasher.new
|
|
19
|
+
```
|
|
20
|
+
- The mopti gem is used in [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
|
|
21
|
+
You only need to require mopti gem when using NeighbourhoodComponentAnalysis.
|
|
22
|
+
```ruby
|
|
23
|
+
require 'mopti'
|
|
24
|
+
require 'rumale'
|
|
25
|
+
|
|
26
|
+
transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
|
|
27
|
+
```
|
|
28
|
+
- Change the default value of solver parameter on [PCA](https://yoshoku.github.io/rumale/doc/Rumale/Decomposition/PCA.html) to 'auto'.
|
|
29
|
+
If Numo::Linalg is loaded, 'evd' is selected for the solver, otherwise 'fpt' is selected.
|
|
30
|
+
- Deprecate [PolynomialModel](https://yoshoku.github.io/rumale/doc/Rumale/PolynomialModel.html), [Optimizer](https://yoshoku.github.io/rumale/doc/Rumale/Optimizer.html), and the estimators contained in them. They will be deleted in version 0.20.0.
|
|
31
|
+
- Many machine learning libraries do not contain factorization machine algorithms, they are provided by another compatible library.
|
|
32
|
+
In addition, there are no plans to implement estimators in PolynomialModel.
|
|
33
|
+
Thus, the author decided to deprecate PolynomialModel.
|
|
34
|
+
- Currently, the Optimizer classes are only used by PolynomialModel estimators.
|
|
35
|
+
Therefore, they have been deprecated together with PolynomialModel.
|
|
36
|
+
|
|
37
|
+
# 0.18.7
|
|
38
|
+
- Fix to convert target_name to string array in [classification_report method](https://yoshoku.github.io/rumale/doc/Rumale/EvaluationMeasure.html#classification_report-class_method).
|
|
39
|
+
- Refactor some codes with Rubocop.
|
|
40
|
+
|
|
41
|
+
# 0.18.6
|
|
42
|
+
- Fix some configuration files.
|
|
43
|
+
- Update API documentation.
|
|
44
|
+
|
|
1
45
|
# 0.18.5
|
|
2
46
|
- Add functions for calculation of cosine similarity and distance to [Rumale::PairwiseMetric](https://yoshoku.github.io/rumale/doc/Rumale/PairwiseMetric.html).
|
|
3
47
|
- Refactor some codes with Rubocop.
|
data/Gemfile
CHANGED
|
@@ -2,3 +2,12 @@ source 'https://rubygems.org'
|
|
|
2
2
|
|
|
3
3
|
# Specify your gem's dependencies in rumale.gemspec
|
|
4
4
|
gemspec
|
|
5
|
+
|
|
6
|
+
gem 'coveralls', '~> 0.8'
|
|
7
|
+
gem 'mmh3', '>= 1.0'
|
|
8
|
+
gem 'mopti', '>= 0.1.0'
|
|
9
|
+
gem 'numo-linalg', '>= 0.1.4'
|
|
10
|
+
gem 'parallel', '>= 1.17.0'
|
|
11
|
+
gem 'rake', '~> 12.0'
|
|
12
|
+
gem 'rake-compiler', '~> 1.0'
|
|
13
|
+
gem 'rspec', '~> 3.0'
|
data/README.md
CHANGED
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
|
|
12
12
|
Rumale provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
|
|
13
13
|
Rumale supports Support Vector Machine,
|
|
14
|
-
Logistic Regression, Ridge, Lasso,
|
|
14
|
+
Logistic Regression, Ridge, Lasso,
|
|
15
15
|
Multi-layer Perceptron,
|
|
16
16
|
Naive Bayes, Decision Tree, Gradient Tree Boosting, Random Forest,
|
|
17
17
|
K-Means, Gaussian Mixture Model, DBSCAN, Spectral Clustering,
|
|
@@ -42,39 +42,7 @@ Or install it yourself as:
|
|
|
42
42
|
|
|
43
43
|
## Usage
|
|
44
44
|
|
|
45
|
-
### Example 1.
|
|
46
|
-
First, let's classify simple xor data.
|
|
47
|
-
|
|
48
|
-
```ruby
|
|
49
|
-
require 'rumale'
|
|
50
|
-
|
|
51
|
-
# Prepare XOR data.
|
|
52
|
-
samples = [[0, 0], [0, 1], [1, 0], [1, 1]]
|
|
53
|
-
labels = [0, 1, 1, 0]
|
|
54
|
-
|
|
55
|
-
# Train classifier with nearest neighbor rule.
|
|
56
|
-
estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1)
|
|
57
|
-
estimator.fit(samples, labels)
|
|
58
|
-
|
|
59
|
-
# Predict labels.
|
|
60
|
-
p labels
|
|
61
|
-
p estimator.predict(samples)
|
|
62
|
-
```
|
|
63
|
-
|
|
64
|
-
Execution of the above script result in the following.
|
|
65
|
-
|
|
66
|
-
```ruby
|
|
67
|
-
[0, 1, 1, 0]
|
|
68
|
-
Numo::Int32#shape=[4]
|
|
69
|
-
[0, 1, 1, 0]
|
|
70
|
-
```
|
|
71
|
-
|
|
72
|
-
The basic usage of Rumale is to first train the model with the fit method
|
|
73
|
-
and then estimate with the predict method.
|
|
74
|
-
In addition, Rumale recommends using arrays such as feature vectors and labels with
|
|
75
|
-
[Numo::NArray](https://github.com/ruby-numo/numo-narray).
|
|
76
|
-
|
|
77
|
-
### Example 2. Pendigits dataset classification
|
|
45
|
+
### Example 1. Pendigits dataset classification
|
|
78
46
|
|
|
79
47
|
Rumale provides function loading libsvm format dataset file.
|
|
80
48
|
We start by downloading the pendigits dataset from LIBSVM Data web site.
|
|
@@ -137,7 +105,7 @@ $ ruby test.rb
|
|
|
137
105
|
Accuracy: 98.7%
|
|
138
106
|
```
|
|
139
107
|
|
|
140
|
-
### Example
|
|
108
|
+
### Example 2. Cross-validation
|
|
141
109
|
|
|
142
110
|
```ruby
|
|
143
111
|
require 'rumale'
|
|
@@ -168,7 +136,7 @@ $ ruby cross_validation.rb
|
|
|
168
136
|
5-CV mean log-loss: 0.355
|
|
169
137
|
```
|
|
170
138
|
|
|
171
|
-
### Example
|
|
139
|
+
### Example 3. Pipeline
|
|
172
140
|
|
|
173
141
|
```ruby
|
|
174
142
|
require 'rumale'
|
|
@@ -200,9 +168,10 @@ $ ruby pipeline.rb
|
|
|
200
168
|
5-CV mean accuracy: 99.6 %
|
|
201
169
|
```
|
|
202
170
|
|
|
203
|
-
##
|
|
171
|
+
## Speed up
|
|
204
172
|
|
|
205
173
|
### Numo::Linalg
|
|
174
|
+
Rumale uses [Numo::NArray](https://github.com/ruby-numo/numo-narray) for typed arrays.
|
|
206
175
|
Loading [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) makes it possible to perform matrix products of Numo::NArray using BLAS libraries.
|
|
207
176
|
For example, using the [OpenBLAS](https://github.com/xianyi/OpenBLAS) speeds up many estimators in Rumale.
|
|
208
177
|
|
|
@@ -259,13 +228,6 @@ When -1 is given to n_jobs parameter, all processors are used.
|
|
|
259
228
|
estimator = Rumale::Ensemble::RandomForestClassifier.new(n_jobs: -1, random_seed: 1)
|
|
260
229
|
```
|
|
261
230
|
|
|
262
|
-
|
|
263
|
-
## Development
|
|
264
|
-
|
|
265
|
-
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
|
266
|
-
|
|
267
|
-
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
|
268
|
-
|
|
269
231
|
## Contributing
|
|
270
232
|
|
|
271
233
|
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
|
data/lib/rumale.rb
CHANGED
|
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
|
|
|
70
70
|
require 'rumale/ensemble/extra_trees_classifier'
|
|
71
71
|
require 'rumale/ensemble/extra_trees_regressor'
|
|
72
72
|
require 'rumale/clustering/k_means'
|
|
73
|
+
require 'rumale/clustering/mini_batch_k_means'
|
|
73
74
|
require 'rumale/clustering/k_medoids'
|
|
74
75
|
require 'rumale/clustering/gaussian_mixture'
|
|
75
76
|
require 'rumale/clustering/dbscan'
|
|
@@ -92,7 +93,9 @@ require 'rumale/neural_network/mlp_regressor'
|
|
|
92
93
|
require 'rumale/neural_network/mlp_classifier'
|
|
93
94
|
require 'rumale/feature_extraction/hash_vectorizer'
|
|
94
95
|
require 'rumale/feature_extraction/feature_hasher'
|
|
96
|
+
require 'rumale/feature_extraction/tfidf_transformer'
|
|
95
97
|
require 'rumale/preprocessing/l2_normalizer'
|
|
98
|
+
require 'rumale/preprocessing/l1_normalizer'
|
|
96
99
|
require 'rumale/preprocessing/min_max_scaler'
|
|
97
100
|
require 'rumale/preprocessing/max_abs_scaler'
|
|
98
101
|
require 'rumale/preprocessing/standard_scaler'
|
|
@@ -25,6 +25,7 @@ module Rumale
|
|
|
25
25
|
|
|
26
26
|
def enable_parallel?
|
|
27
27
|
return false if @params[:n_jobs].nil?
|
|
28
|
+
|
|
28
29
|
if defined?(Parallel).nil?
|
|
29
30
|
warn('If you want to use parallel option, you should install and load Parallel in advance.')
|
|
30
31
|
return false
|
|
@@ -34,6 +35,7 @@ module Rumale
|
|
|
34
35
|
|
|
35
36
|
def n_processes
|
|
36
37
|
return 1 unless enable_parallel?
|
|
38
|
+
|
|
37
39
|
@params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
|
|
38
40
|
end
|
|
39
41
|
|
|
@@ -13,7 +13,7 @@ module Rumale
|
|
|
13
13
|
# cluster_labels = analyzer.fit_predict(samples)
|
|
14
14
|
#
|
|
15
15
|
# *Reference*
|
|
16
|
-
# - M
|
|
16
|
+
# - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
|
|
17
17
|
class DBSCAN
|
|
18
18
|
include Base::BaseEstimator
|
|
19
19
|
include Base::ClusterAnalyzer
|
|
@@ -54,6 +54,7 @@ module Rumale
|
|
|
54
54
|
def fit(x, _y = nil)
|
|
55
55
|
x = check_convert_sample_array(x)
|
|
56
56
|
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
|
|
57
|
+
|
|
57
58
|
partial_fit(x)
|
|
58
59
|
self
|
|
59
60
|
end
|
|
@@ -66,6 +67,7 @@ module Rumale
|
|
|
66
67
|
def fit_predict(x)
|
|
67
68
|
x = check_convert_sample_array(x)
|
|
68
69
|
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
|
|
70
|
+
|
|
69
71
|
partial_fit(x)
|
|
70
72
|
labels
|
|
71
73
|
end
|
|
@@ -80,6 +82,7 @@ module Rumale
|
|
|
80
82
|
@labels = Numo::Int32.zeros(n_samples) - 2
|
|
81
83
|
n_samples.times do |query_id|
|
|
82
84
|
next if @labels[query_id] >= -1
|
|
85
|
+
|
|
83
86
|
cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
|
|
84
87
|
end
|
|
85
88
|
@core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
|
|
@@ -102,6 +105,7 @@ module Rumale
|
|
|
102
105
|
while (m = target_ids.shift)
|
|
103
106
|
neighbor_ids = region_query(metric_mat[m, true])
|
|
104
107
|
next if neighbor_ids.size < @params[:min_samples]
|
|
108
|
+
|
|
105
109
|
neighbor_ids.each do |n|
|
|
106
110
|
target_ids.push(n) if @labels[n] < -1
|
|
107
111
|
@labels[n] = cluster_id if @labels[n] <= -1
|
|
@@ -86,6 +86,7 @@ module Rumale
|
|
|
86
86
|
new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
|
|
87
87
|
error = (memberships - new_memberships).abs.max
|
|
88
88
|
break if error <= @params[:tol]
|
|
89
|
+
|
|
89
90
|
memberships = new_memberships.dup
|
|
90
91
|
end
|
|
91
92
|
self
|
|
@@ -209,6 +210,7 @@ module Rumale
|
|
|
209
210
|
|
|
210
211
|
def check_enable_linalg(method_name)
|
|
211
212
|
return unless @params[:covariance_type] == 'full' && !enable_linalg?
|
|
213
|
+
|
|
212
214
|
raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
|
|
213
215
|
end
|
|
214
216
|
end
|
|
@@ -15,9 +15,9 @@ module Rumale
|
|
|
15
15
|
# cluster_labels = analyzer.fit_predict(samples)
|
|
16
16
|
#
|
|
17
17
|
# *Reference*
|
|
18
|
-
# - R J. G. B
|
|
19
|
-
# - R J. G. B
|
|
20
|
-
# - L
|
|
18
|
+
# - Campello, R J. G. B., Moulavi, D., Zimek, A., and Sander, J., "Hierarchical Density Estimates for Data Clustering, Visualization, and Outlier Detection," TKDD, Vol. 10 (1), pp. 5:1--5:51, 2015.
|
|
19
|
+
# - Campello, R J. G. B., Moulavi, D., and Sander, J., "Density-Based Clustering Based on Hierarchical Density Estimates," Proc. PAKDD'13, pp. 160--172, 2013.
|
|
20
|
+
# - Lelis, L., and Sander, J., "Semi-Supervised Density-Based Clustering," Proc. ICDM'09, pp. 842--847, 2009.
|
|
21
21
|
class HDBSCAN
|
|
22
22
|
include Base::BaseEstimator
|
|
23
23
|
include Base::ClusterAnalyzer
|
|
@@ -55,6 +55,7 @@ module Rumale
|
|
|
55
55
|
def fit(x, _y = nil)
|
|
56
56
|
x = check_convert_sample_array(x)
|
|
57
57
|
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
|
|
58
|
+
|
|
58
59
|
fit_predict(x)
|
|
59
60
|
self
|
|
60
61
|
end
|
|
@@ -67,6 +68,7 @@ module Rumale
|
|
|
67
68
|
def fit_predict(x)
|
|
68
69
|
x = check_convert_sample_array(x)
|
|
69
70
|
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
|
|
71
|
+
|
|
70
72
|
distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
|
|
71
73
|
@labels = partial_fit(distance_mat)
|
|
72
74
|
end
|
|
@@ -15,7 +15,7 @@ module Rumale
|
|
|
15
15
|
# cluster_labels = analyzer.fit_predict(samples)
|
|
16
16
|
#
|
|
17
17
|
# *Reference*
|
|
18
|
-
# - D
|
|
18
|
+
# - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
|
|
19
19
|
class KMeans
|
|
20
20
|
include Base::BaseEstimator
|
|
21
21
|
include Base::ClusterAnalyzer
|
|
@@ -106,6 +106,7 @@ module Rumale
|
|
|
106
106
|
rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
|
|
107
107
|
@cluster_centers = x[rand_id, true].dup
|
|
108
108
|
return unless @params[:init] == 'k-means++'
|
|
109
|
+
|
|
109
110
|
# k-means++ initialize
|
|
110
111
|
(1...@params[:n_clusters]).each do |n|
|
|
111
112
|
distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
|
|
@@ -13,7 +13,7 @@ module Rumale
|
|
|
13
13
|
# cluster_labels = analyzer.fit_predict(samples)
|
|
14
14
|
#
|
|
15
15
|
# *Reference*
|
|
16
|
-
# - D
|
|
16
|
+
# - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
|
|
17
17
|
class KMedoids
|
|
18
18
|
include Base::BaseEstimator
|
|
19
19
|
include Base::ClusterAnalyzer
|
|
@@ -64,6 +64,7 @@ module Rumale
|
|
|
64
64
|
def fit(x, _not_used = nil)
|
|
65
65
|
x = check_convert_sample_array(x)
|
|
66
66
|
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
|
|
67
|
+
|
|
67
68
|
# initialize some variables.
|
|
68
69
|
distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
|
|
69
70
|
init_cluster_centers(distance_mat)
|
|
@@ -76,6 +77,7 @@ module Rumale
|
|
|
76
77
|
end
|
|
77
78
|
new_error = distance_mat[true, @medoid_ids].mean
|
|
78
79
|
break if (error - new_error).abs <= @params[:tol]
|
|
80
|
+
|
|
79
81
|
error = new_error
|
|
80
82
|
end
|
|
81
83
|
@cluster_centers = x[@medoid_ids, true].dup if @params[:metric] == 'euclidean'
|
|
@@ -93,6 +95,7 @@ module Rumale
|
|
|
93
95
|
if @params[:metric] == 'precomputed' && distance_mat.shape[1] != @medoid_ids.size
|
|
94
96
|
raise ArgumentError, 'Expect the size input matrix to be n_samples-by-n_clusters.'
|
|
95
97
|
end
|
|
98
|
+
|
|
96
99
|
assign_cluster(distance_mat)
|
|
97
100
|
end
|
|
98
101
|
|
|
@@ -123,6 +126,7 @@ module Rumale
|
|
|
123
126
|
sub_rng = @rng.dup
|
|
124
127
|
@medoid_ids = Numo::Int32.asarray([*0...n_samples].sample(@params[:n_clusters], random: sub_rng))
|
|
125
128
|
return unless @params[:init] == 'k-means++'
|
|
129
|
+
|
|
126
130
|
# k-means++ initialize
|
|
127
131
|
(1...@params[:n_clusters]).each do |n|
|
|
128
132
|
distances = distance_mat[true, @medoid_ids[0...n]]
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'rumale/base/base_estimator'
|
|
4
|
+
require 'rumale/base/cluster_analyzer'
|
|
5
|
+
require 'rumale/pairwise_metric'
|
|
6
|
+
|
|
7
|
+
module Rumale
|
|
8
|
+
module Clustering
|
|
9
|
+
# MiniBatchKMeans is a class that implements K-Means cluster analysis
|
|
10
|
+
# with mini-batch stochastic gradient descent (SGD).
|
|
11
|
+
#
|
|
12
|
+
# @example
|
|
13
|
+
# analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
|
|
14
|
+
# cluster_labels = analyzer.fit_predict(samples)
|
|
15
|
+
#
|
|
16
|
+
# *Reference*
|
|
17
|
+
# - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
|
|
18
|
+
class MiniBatchKMeans
|
|
19
|
+
include Base::BaseEstimator
|
|
20
|
+
include Base::ClusterAnalyzer
|
|
21
|
+
|
|
22
|
+
# Return the centroids.
|
|
23
|
+
# @return [Numo::DFloat] (shape: [n_clusters, n_features])
|
|
24
|
+
attr_reader :cluster_centers
|
|
25
|
+
|
|
26
|
+
# Return the random generator.
|
|
27
|
+
# @return [Random]
|
|
28
|
+
attr_reader :rng
|
|
29
|
+
|
|
30
|
+
# Create a new cluster analyzer with K-Means method with mini-batch SGD.
|
|
31
|
+
#
|
|
32
|
+
# @param n_clusters [Integer] The number of clusters.
|
|
33
|
+
# @param init [String] The initialization method for centroids ('random' or 'k-means++').
|
|
34
|
+
# @param max_iter [Integer] The maximum number of iterations.
|
|
35
|
+
# @param batch_size [Integer] The size of the mini batches.
|
|
36
|
+
# @param tol [Float] The tolerance of termination criterion.
|
|
37
|
+
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
|
38
|
+
def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
|
|
39
|
+
check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
|
|
40
|
+
check_params_string(init: init)
|
|
41
|
+
check_params_numeric_or_nil(random_seed: random_seed)
|
|
42
|
+
check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
|
|
43
|
+
@params = {}
|
|
44
|
+
@params[:n_clusters] = n_clusters
|
|
45
|
+
@params[:init] = init == 'random' ? 'random' : 'k-means++'
|
|
46
|
+
@params[:max_iter] = max_iter
|
|
47
|
+
@params[:batch_size] = batch_size
|
|
48
|
+
@params[:tol] = tol
|
|
49
|
+
@params[:random_seed] = random_seed
|
|
50
|
+
@params[:random_seed] ||= srand
|
|
51
|
+
@cluster_centers = nil
|
|
52
|
+
@rng = Random.new(@params[:random_seed])
|
|
53
|
+
end
|
|
54
|
+
|
|
55
|
+
# Analyze clusters with the given training data.
|
|
56
|
+
#
|
|
57
|
+
# @overload fit(x) -> MiniBatchKMeans
|
|
58
|
+
#
|
|
59
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
|
60
|
+
# @return [MiniBatchKMeans] The learned cluster analyzer itself.
|
|
61
|
+
def fit(x, _y = nil)
|
|
62
|
+
x = check_convert_sample_array(x)
|
|
63
|
+
# initialization.
|
|
64
|
+
n_samples = x.shape[0]
|
|
65
|
+
update_counter = Numo::Int32.zeros(@params[:n_clusters])
|
|
66
|
+
sub_rng = @rng.dup
|
|
67
|
+
init_cluster_centers(x, sub_rng)
|
|
68
|
+
# optimization with mini-batch sgd.
|
|
69
|
+
@params[:max_iter].times do |_t|
|
|
70
|
+
sample_ids = [*0...n_samples].shuffle(random: sub_rng)
|
|
71
|
+
old_centers = @cluster_centers.dup
|
|
72
|
+
until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
|
|
73
|
+
# sub sampling
|
|
74
|
+
sub_x = x[subset_ids, true]
|
|
75
|
+
# assign nearest centroids
|
|
76
|
+
cluster_labels = assign_cluster(sub_x)
|
|
77
|
+
# update centroids
|
|
78
|
+
@params[:n_clusters].times do |c|
|
|
79
|
+
assigned_bits = cluster_labels.eq(c)
|
|
80
|
+
next unless assigned_bits.count.positive?
|
|
81
|
+
|
|
82
|
+
update_counter[c] += 1
|
|
83
|
+
learning_rate = 1.fdiv(update_counter[c])
|
|
84
|
+
update = sub_x[assigned_bits.where, true].mean(axis: 0)
|
|
85
|
+
@cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
|
|
86
|
+
end
|
|
87
|
+
end
|
|
88
|
+
error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
|
|
89
|
+
break if error <= @params[:tol]
|
|
90
|
+
end
|
|
91
|
+
self
|
|
92
|
+
end
|
|
93
|
+
|
|
94
|
+
# Predict cluster labels for samples.
|
|
95
|
+
#
|
|
96
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
|
|
97
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
|
98
|
+
def predict(x)
|
|
99
|
+
x = check_convert_sample_array(x)
|
|
100
|
+
assign_cluster(x)
|
|
101
|
+
end
|
|
102
|
+
|
|
103
|
+
# Analyze clusters and assign samples to clusters.
|
|
104
|
+
#
|
|
105
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
|
|
106
|
+
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
|
107
|
+
def fit_predict(x)
|
|
108
|
+
x = check_convert_sample_array(x)
|
|
109
|
+
fit(x)
|
|
110
|
+
predict(x)
|
|
111
|
+
end
|
|
112
|
+
|
|
113
|
+
private
|
|
114
|
+
|
|
115
|
+
def assign_cluster(x)
|
|
116
|
+
distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
|
|
117
|
+
distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
|
|
118
|
+
end
|
|
119
|
+
|
|
120
|
+
def init_cluster_centers(x, sub_rng)
|
|
121
|
+
# random initialize
|
|
122
|
+
n_samples = x.shape[0]
|
|
123
|
+
rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
|
|
124
|
+
@cluster_centers = x[rand_id, true].dup
|
|
125
|
+
return unless @params[:init] == 'k-means++'
|
|
126
|
+
|
|
127
|
+
# k-means++ initialize
|
|
128
|
+
(1...@params[:n_clusters]).each do |n|
|
|
129
|
+
distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
|
|
130
|
+
min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
|
|
131
|
+
probs = min_distances**2 / (min_distances**2).sum
|
|
132
|
+
cum_probs = probs.cumsum
|
|
133
|
+
selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
|
|
134
|
+
@cluster_centers[n, true] = x[selected_id, true].dup
|
|
135
|
+
end
|
|
136
|
+
end
|
|
137
|
+
end
|
|
138
|
+
end
|
|
139
|
+
end
|