rumale 0.18.4 → 0.19.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (93) hide show
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +31 -3
  3. data/.travis.yml +3 -3
  4. data/CHANGELOG.md +43 -0
  5. data/Gemfile +9 -0
  6. data/README.md +6 -44
  7. data/lib/rumale.rb +1 -0
  8. data/lib/rumale/base/base_estimator.rb +2 -0
  9. data/lib/rumale/clustering/dbscan.rb +5 -1
  10. data/lib/rumale/clustering/gaussian_mixture.rb +2 -0
  11. data/lib/rumale/clustering/hdbscan.rb +7 -3
  12. data/lib/rumale/clustering/k_means.rb +2 -1
  13. data/lib/rumale/clustering/k_medoids.rb +5 -1
  14. data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
  15. data/lib/rumale/clustering/power_iteration.rb +3 -1
  16. data/lib/rumale/clustering/single_linkage.rb +3 -1
  17. data/lib/rumale/clustering/snn.rb +2 -2
  18. data/lib/rumale/clustering/spectral_clustering.rb +2 -2
  19. data/lib/rumale/dataset.rb +2 -0
  20. data/lib/rumale/decomposition/factor_analysis.rb +3 -1
  21. data/lib/rumale/decomposition/fast_ica.rb +2 -2
  22. data/lib/rumale/decomposition/nmf.rb +1 -1
  23. data/lib/rumale/decomposition/pca.rb +25 -6
  24. data/lib/rumale/ensemble/ada_boost_classifier.rb +4 -1
  25. data/lib/rumale/ensemble/ada_boost_regressor.rb +4 -2
  26. data/lib/rumale/ensemble/extra_trees_classifier.rb +1 -1
  27. data/lib/rumale/ensemble/extra_trees_regressor.rb +1 -1
  28. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +4 -4
  29. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +7 -9
  30. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +1 -1
  31. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +1 -1
  32. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +1 -1
  33. data/lib/rumale/evaluation_measure/function.rb +9 -5
  34. data/lib/rumale/evaluation_measure/mutual_information.rb +1 -1
  35. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +4 -2
  36. data/lib/rumale/evaluation_measure/precision_recall.rb +5 -0
  37. data/lib/rumale/evaluation_measure/purity.rb +1 -1
  38. data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
  39. data/lib/rumale/evaluation_measure/silhouette_score.rb +3 -1
  40. data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
  41. data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -0
  42. data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
  43. data/lib/rumale/kernel_approximation/rbf.rb +1 -1
  44. data/lib/rumale/kernel_machine/kernel_fda.rb +2 -2
  45. data/lib/rumale/kernel_machine/kernel_pca.rb +1 -1
  46. data/lib/rumale/kernel_machine/kernel_ridge.rb +2 -0
  47. data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
  48. data/lib/rumale/linear_model/base_linear_model.rb +2 -0
  49. data/lib/rumale/linear_model/elastic_net.rb +3 -3
  50. data/lib/rumale/linear_model/lasso.rb +3 -3
  51. data/lib/rumale/linear_model/linear_regression.rb +2 -1
  52. data/lib/rumale/linear_model/logistic_regression.rb +3 -3
  53. data/lib/rumale/linear_model/ridge.rb +2 -1
  54. data/lib/rumale/linear_model/svc.rb +3 -3
  55. data/lib/rumale/linear_model/svr.rb +3 -3
  56. data/lib/rumale/manifold/mds.rb +3 -1
  57. data/lib/rumale/manifold/tsne.rb +6 -2
  58. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +14 -1
  59. data/lib/rumale/model_selection/grid_search_cv.rb +1 -0
  60. data/lib/rumale/naive_bayes/bernoulli_nb.rb +1 -1
  61. data/lib/rumale/naive_bayes/multinomial_nb.rb +1 -1
  62. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +1 -0
  63. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +2 -0
  64. data/lib/rumale/nearest_neighbors/vp_tree.rb +6 -8
  65. data/lib/rumale/neural_network/adam.rb +2 -2
  66. data/lib/rumale/neural_network/base_mlp.rb +1 -0
  67. data/lib/rumale/optimizer/ada_grad.rb +4 -1
  68. data/lib/rumale/optimizer/adam.rb +4 -1
  69. data/lib/rumale/optimizer/nadam.rb +6 -1
  70. data/lib/rumale/optimizer/rmsprop.rb +5 -2
  71. data/lib/rumale/optimizer/sgd.rb +3 -0
  72. data/lib/rumale/optimizer/yellow_fin.rb +4 -1
  73. data/lib/rumale/pairwise_metric.rb +33 -0
  74. data/lib/rumale/pipeline/pipeline.rb +3 -0
  75. data/lib/rumale/polynomial_model/base_factorization_machine.rb +5 -0
  76. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +7 -2
  77. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +7 -2
  78. data/lib/rumale/preprocessing/one_hot_encoder.rb +3 -0
  79. data/lib/rumale/preprocessing/ordinal_encoder.rb +2 -0
  80. data/lib/rumale/preprocessing/polynomial_features.rb +1 -0
  81. data/lib/rumale/probabilistic_output.rb +4 -2
  82. data/lib/rumale/tree/base_decision_tree.rb +2 -0
  83. data/lib/rumale/tree/decision_tree_classifier.rb +1 -0
  84. data/lib/rumale/tree/extra_tree_classifier.rb +1 -1
  85. data/lib/rumale/tree/extra_tree_regressor.rb +1 -1
  86. data/lib/rumale/tree/gradient_tree_regressor.rb +5 -5
  87. data/lib/rumale/utils.rb +1 -0
  88. data/lib/rumale/validation.rb +7 -0
  89. data/lib/rumale/version.rb +1 -1
  90. data/rumale.gemspec +1 -13
  91. metadata +8 -135
  92. data/bin/console +0 -14
  93. data/bin/setup +0 -8
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a4ec0f04029a88ea7950bb74bf9e0bd970cf8b546253da4ef4545244ad0adf68
4
- data.tar.gz: ed927fb58359fe49d4834593ec2034535f5ac8d37c773d6cb5a53227c4777f61
3
+ metadata.gz: f49170105721cfebcae9f1a424e9a858650d78225541a8cb63b0ad4c70734988
4
+ data.tar.gz: ecc35086328eee1066252e75b8cd638256039e93beebc0bce5714493fe72570b
5
5
  SHA512:
6
- metadata.gz: 58727ddc9c5c6f9c12ac8231f57295988795bf07ec6f748543fff924a4d8085472a210789992ba2a3cc4827300823d141313da1fd20a10be3b5970b40871e4e3
7
- data.tar.gz: cc1af7ff552a9fe516e0588081ad6645cb03d4c8dd774492ea35c3e1a28eece7383dc88604c41e1cc1417f5d05af1324a7b97c121cb7ee76bcc697e4173195be
6
+ metadata.gz: 68f432bb34ff6c8e467a91d7c7e3aa07e816c2dd8807defc9e4e82e7a720c925062dbd27c8a7ec3294ecef2d71041baead2510edaf03a1eee210dc811eede22d
7
+ data.tar.gz: 5854eacc12de6c3cdcdbab0f9b4e73fc64d1be0533732348da6b4d6dcb0be9f115e2415501b05148fd021fa844ac0c25adc1bb858432a02ca6fe19d30a3538c7
@@ -3,11 +3,11 @@ require:
3
3
  - rubocop-rspec
4
4
 
5
5
  AllCops:
6
- TargetRubyVersion: 2.3
6
+ TargetRubyVersion: 2.5
7
7
  DisplayCopNames: true
8
8
  DisplayStyleGuide: true
9
9
  Exclude:
10
- - 'bin/*'
10
+ - 'ext/rumale/extconf.rb'
11
11
  - 'rumale.gemspec'
12
12
  - 'Rakefile'
13
13
  - 'Gemfile'
@@ -15,10 +15,26 @@ AllCops:
15
15
  Style/Documentation:
16
16
  Enabled: false
17
17
 
18
+ Style/HashEachMethods:
19
+ Enabled: true
20
+
21
+ Style/HashTransformKeys:
22
+ Enabled: true
23
+
24
+ Style/HashTransformValues:
25
+ Enabled: true
26
+
27
+ Lint/RaiseException:
28
+ Enabled: true
29
+
30
+ Lint/StructNewOverride:
31
+ Enabled: true
32
+
18
33
  Layout/LineLength:
19
34
  Max: 145
20
35
  IgnoredPatterns: ['(\A|\s)#']
21
36
 
37
+
22
38
  Metrics/ModuleLength:
23
39
  Max: 200
24
40
 
@@ -54,14 +70,26 @@ Naming/MethodParameterName:
54
70
  Naming/ConstantName:
55
71
  Enabled: false
56
72
 
73
+ Style/ExponentialNotation:
74
+ Enabled: true
75
+
57
76
  Style/FormatStringToken:
58
77
  Enabled: false
59
78
 
60
79
  Style/NumericLiterals:
61
80
  Enabled: false
62
81
 
82
+ Style/SlicingWithRange:
83
+ Enabled: true
84
+
63
85
  Layout/EmptyLineAfterGuardClause:
64
- Enabled: false
86
+ Enabled: true
87
+
88
+ Layout/EmptyLinesAroundAttributeAccessor:
89
+ Enabled: true
90
+
91
+ Layout/SpaceAroundMethodCallOperator:
92
+ Enabled: true
65
93
 
66
94
  RSpec/MultipleExpectations:
67
95
  Enabled: false
@@ -1,6 +1,6 @@
1
- os: linux
2
- dist: xenial
1
+ ---
3
2
  language: ruby
3
+ cache: bundler
4
4
  rvm:
5
5
  - '2.4'
6
6
  - '2.5'
@@ -14,4 +14,4 @@ addons:
14
14
  - liblapacke-dev
15
15
 
16
16
  before_install:
17
- - gem install bundler -v 2.0.2
17
+ - gem install bundler -v 2.1.4
@@ -1,3 +1,46 @@
1
+ # 0.19.1
2
+ - Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
3
+ - Fix some typos.
4
+
5
+ # 0.19.0
6
+ - Change mmh3 and mopti gem to non-runtime dependent library.
7
+ - The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
8
+ You only need to require mmh3 gem when using FeatureHasher.
9
+ ```ruby
10
+ require 'mmh3'
11
+ require 'rumale'
12
+
13
+ encoder = Rumale::FeatureExtraction::FeatureHasher.new
14
+ ```
15
+ - The mopti gem is used in [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
16
+ You only need to require mopti gem when using NeighbourhoodComponentAnalysis.
17
+ ```ruby
18
+ require 'mopti'
19
+ require 'rumale'
20
+
21
+ transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
22
+ ```
23
+ - Change the default value of solver parameter on [PCA](https://yoshoku.github.io/rumale/doc/Rumale/Decomposition/PCA.html) to 'auto'.
24
+ If Numo::Linalg is loaded, 'evd' is selected for the solver, otherwise 'fpt' is selected.
25
+ - Deprecate [PolynomialModel](https://yoshoku.github.io/rumale/doc/Rumale/PolynomialModel.html), [Optimizer](https://yoshoku.github.io/rumale/doc/Rumale/Optimizer.html), and the estimators contained in them. They will be deleted in version 0.20.0.
26
+ - Many machine learning libraries do not contain factorization machine algorithms; instead, they are provided by separate compatible libraries.
27
+ In addition, there are no plans to implement estimators in PolynomialModel.
28
+ Thus, the author decided to deprecate PolynomialModel.
29
+ - Currently, the Optimizer classes are only used by PolynomialModel estimators.
30
+ Therefore, they have been deprecated together with PolynomialModel.
31
+
32
+ # 0.18.7
33
+ - Fix to convert target_name to string array in [classification_report method](https://yoshoku.github.io/rumale/doc/Rumale/EvaluationMeasure.html#classification_report-class_method).
34
+ - Refactor some code with RuboCop.
35
+
36
+ # 0.18.6
37
+ - Fix some configuration files.
38
+ - Update API documentation.
39
+
40
+ # 0.18.5
41
+ - Add functions for calculation of cosine similarity and distance to [Rumale::PairwiseMetric](https://yoshoku.github.io/rumale/doc/Rumale/PairwiseMetric.html).
42
+ - Refactor some code with RuboCop.
43
+
1
44
  # 0.18.4
2
45
  - Add transformer class for [KernelFDA](https://yoshoku.github.io/rumale/doc/Rumale/KernelMachine/KernelFDA.html).
3
46
  - Refactor [KernelPCA](https://yoshoku.github.io/rumale/doc/Rumale/KernelMachine/KernelPCA.html).
data/Gemfile CHANGED
@@ -2,3 +2,12 @@ source 'https://rubygems.org'
2
2
 
3
3
  # Specify your gem's dependencies in rumale.gemspec
4
4
  gemspec
5
+
6
+ gem 'coveralls', '~> 0.8'
7
+ gem 'mmh3', '>= 1.0'
8
+ gem 'mopti', '>= 0.1.0'
9
+ gem 'numo-linalg', '>= 0.1.4'
10
+ gem 'parallel', '>= 1.17.0'
11
+ gem 'rake', '~> 12.0'
12
+ gem 'rake-compiler', '~> 1.0'
13
+ gem 'rspec', '~> 3.0'
data/README.md CHANGED
@@ -11,7 +11,7 @@
11
11
  Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
12
12
  Rumale provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
13
13
  Rumale supports Support Vector Machine,
14
- Logistic Regression, Ridge, Lasso, Factorization Machine,
14
+ Logistic Regression, Ridge, Lasso,
15
15
  Multi-layer Perceptron,
16
16
  Naive Bayes, Decision Tree, Gradient Tree Boosting, Random Forest,
17
17
  K-Means, Gaussian Mixture Model, DBSCAN, Spectral Clustering,
@@ -42,39 +42,7 @@ Or install it yourself as:
42
42
 
43
43
  ## Usage
44
44
 
45
- ### Example 1. XOR data
46
- First, let's classify simple xor data.
47
-
48
- ```ruby
49
- require 'rumale'
50
-
51
- # Prepare XOR data.
52
- samples = [[0, 0], [0, 1], [1, 0], [1, 1]]
53
- labels = [0, 1, 1, 0]
54
-
55
- # Train classifier with nearest neighbor rule.
56
- estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1)
57
- estimator.fit(samples, labels)
58
-
59
- # Predict labels.
60
- p labels
61
- p estimator.predict(samples)
62
- ```
63
-
64
- Execution of the above script result in the following.
65
-
66
- ```ruby
67
- [0, 1, 1, 0]
68
- Numo::Int32#shape=[4]
69
- [0, 1, 1, 0]
70
- ```
71
-
72
- The basic usage of Rumale is to first train the model with the fit method
73
- and then estimate with the predict method.
74
- In addition, Rumale recommends using arrays such as feature vectors and labels with
75
- [Numo::NArray](https://github.com/ruby-numo/numo-narray).
76
-
77
- ### Example 2. Pendigits dataset classification
45
+ ### Example 1. Pendigits dataset classification
78
46
 
79
47
  Rumale provides function loading libsvm format dataset file.
80
48
  We start by downloading the pendigits dataset from LIBSVM Data web site.
@@ -137,7 +105,7 @@ $ ruby test.rb
137
105
  Accuracy: 98.7%
138
106
  ```
139
107
 
140
- ### Example 3. Cross-validation
108
+ ### Example 2. Cross-validation
141
109
 
142
110
  ```ruby
143
111
  require 'rumale'
@@ -168,7 +136,7 @@ $ ruby cross_validation.rb
168
136
  5-CV mean log-loss: 0.355
169
137
  ```
170
138
 
171
- ### Example 4. Pipeline
139
+ ### Example 3. Pipeline
172
140
 
173
141
  ```ruby
174
142
  require 'rumale'
@@ -200,9 +168,10 @@ $ ruby pipeline.rb
200
168
  5-CV mean accuracy: 99.6 %
201
169
  ```
202
170
 
203
- ## Speeding up
171
+ ## Speed up
204
172
 
205
173
  ### Numo::Linalg
174
+ Rumale uses [Numo::NArray](https://github.com/ruby-numo/numo-narray) for typed arrays.
206
175
  Loading [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) allows matrix products of Numo::NArray to be computed using BLAS libraries.
207
176
  For example, using the [OpenBLAS](https://github.com/xianyi/OpenBLAS) speeds up many estimators in Rumale.
208
177
 
@@ -259,13 +228,6 @@ When -1 is given to n_jobs parameter, all processors are used.
259
228
  estimator = Rumale::Ensemble::RandomForestClassifier.new(n_jobs: -1, random_seed: 1)
260
229
  ```
261
230
 
262
-
263
- ## Development
264
-
265
- After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
266
-
267
- To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
268
-
269
231
  ## Contributing
270
232
 
271
233
  Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
70
70
  require 'rumale/ensemble/extra_trees_classifier'
71
71
  require 'rumale/ensemble/extra_trees_regressor'
72
72
  require 'rumale/clustering/k_means'
73
+ require 'rumale/clustering/mini_batch_k_means'
73
74
  require 'rumale/clustering/k_medoids'
74
75
  require 'rumale/clustering/gaussian_mixture'
75
76
  require 'rumale/clustering/dbscan'
@@ -25,6 +25,7 @@ module Rumale
25
25
 
26
26
  def enable_parallel?
27
27
  return false if @params[:n_jobs].nil?
28
+
28
29
  if defined?(Parallel).nil?
29
30
  warn('If you want to use parallel option, you should install and load Parallel in advance.')
30
31
  return false
@@ -34,6 +35,7 @@ module Rumale
34
35
 
35
36
  def n_processes
36
37
  return 1 unless enable_parallel?
38
+
37
39
  @params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
38
40
  end
39
41
 
@@ -13,7 +13,7 @@ module Rumale
13
13
  # cluster_labels = analyzer.fit_predict(samples)
14
14
  #
15
15
  # *Reference*
16
- # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 266--231, 1996.
16
+ # - Ester, M., Kriegel, H-P., Sander, J., and Xu, X., "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
17
17
  class DBSCAN
18
18
  include Base::BaseEstimator
19
19
  include Base::ClusterAnalyzer
@@ -54,6 +54,7 @@ module Rumale
54
54
  def fit(x, _y = nil)
55
55
  x = check_convert_sample_array(x)
56
56
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
57
+
57
58
  partial_fit(x)
58
59
  self
59
60
  end
@@ -66,6 +67,7 @@ module Rumale
66
67
  def fit_predict(x)
67
68
  x = check_convert_sample_array(x)
68
69
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
70
+
69
71
  partial_fit(x)
70
72
  labels
71
73
  end
@@ -80,6 +82,7 @@ module Rumale
80
82
  @labels = Numo::Int32.zeros(n_samples) - 2
81
83
  n_samples.times do |query_id|
82
84
  next if @labels[query_id] >= -1
85
+
83
86
  cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
84
87
  end
85
88
  @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
@@ -102,6 +105,7 @@ module Rumale
102
105
  while (m = target_ids.shift)
103
106
  neighbor_ids = region_query(metric_mat[m, true])
104
107
  next if neighbor_ids.size < @params[:min_samples]
108
+
105
109
  neighbor_ids.each do |n|
106
110
  target_ids.push(n) if @labels[n] < -1
107
111
  @labels[n] = cluster_id if @labels[n] <= -1
@@ -86,6 +86,7 @@ module Rumale
86
86
  new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
87
87
  error = (memberships - new_memberships).abs.max
88
88
  break if error <= @params[:tol]
89
+
89
90
  memberships = new_memberships.dup
90
91
  end
91
92
  self
@@ -209,6 +210,7 @@ module Rumale
209
210
 
210
211
  def check_enable_linalg(method_name)
211
212
  return unless @params[:covariance_type] == 'full' && !enable_linalg?
213
+
212
214
  raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
213
215
  end
214
216
  end
@@ -15,9 +15,9 @@ module Rumale
15
15
  # cluster_labels = analyzer.fit_predict(samples)
16
16
  #
17
17
  # *Reference*
18
- # - R J. G. B. Campello, D. Moulavi, A. Zimek, and J. Sander, "Hierarchical Density Estimates for Data Clustering, Visualization, and Outlier Detection," TKDD, Vol. 10 (1), pp. 5:1--5:51, 2015.
19
- # - R J. G. B. Campello, D. Moulavi, and J Sander, "Density-Based Clustering Based on Hierarchical Density Estimates," Proc. PAKDD'13, pp. 160--172, 2013.
20
- # - L. Lelis and J. Sander, "Semi-Supervised Density-Based Clustering," Proc. ICDM'09, pp. 842--847, 2009.
18
+ # - Campello, R. J. G. B., Moulavi, D., Zimek, A., and Sander, J., "Hierarchical Density Estimates for Data Clustering, Visualization, and Outlier Detection," TKDD, Vol. 10 (1), pp. 5:1--5:51, 2015.
19
+ # - Campello, R. J. G. B., Moulavi, D., and Sander, J., "Density-Based Clustering Based on Hierarchical Density Estimates," Proc. PAKDD'13, pp. 160--172, 2013.
20
+ # - Lelis, L., and Sander, J., "Semi-Supervised Density-Based Clustering," Proc. ICDM'09, pp. 842--847, 2009.
21
21
  class HDBSCAN
22
22
  include Base::BaseEstimator
23
23
  include Base::ClusterAnalyzer
@@ -55,6 +55,7 @@ module Rumale
55
55
  def fit(x, _y = nil)
56
56
  x = check_convert_sample_array(x)
57
57
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
58
+
58
59
  fit_predict(x)
59
60
  self
60
61
  end
@@ -67,6 +68,7 @@ module Rumale
67
68
  def fit_predict(x)
68
69
  x = check_convert_sample_array(x)
69
70
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
71
+
70
72
  distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
71
73
  @labels = partial_fit(distance_mat)
72
74
  end
@@ -134,6 +136,7 @@ module Rumale
134
136
  res
135
137
  end
136
138
 
139
+ # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
137
140
  def condense_tree(hierarchy, min_cluster_size)
138
141
  n_edges = hierarchy.size
139
142
  root = 2 * n_edges
@@ -262,6 +265,7 @@ module Rumale
262
265
  end
263
266
  res
264
267
  end
268
+ # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
265
269
  end
266
270
  end
267
271
  end
@@ -15,7 +15,7 @@ module Rumale
15
15
  # cluster_labels = analyzer.fit_predict(samples)
16
16
  #
17
17
  # *Reference*
18
- # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
18
+ # - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
19
19
  class KMeans
20
20
  include Base::BaseEstimator
21
21
  include Base::ClusterAnalyzer
@@ -106,6 +106,7 @@ module Rumale
106
106
  rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
107
107
  @cluster_centers = x[rand_id, true].dup
108
108
  return unless @params[:init] == 'k-means++'
109
+
109
110
  # k-means++ initialize
110
111
  (1...@params[:n_clusters]).each do |n|
111
112
  distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
@@ -13,7 +13,7 @@ module Rumale
13
13
  # cluster_labels = analyzer.fit_predict(samples)
14
14
  #
15
15
  # *Reference*
16
- # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
16
+ # - Arthur, D., and Vassilvitskii, S., "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
17
17
  class KMedoids
18
18
  include Base::BaseEstimator
19
19
  include Base::ClusterAnalyzer
@@ -64,6 +64,7 @@ module Rumale
64
64
  def fit(x, _not_used = nil)
65
65
  x = check_convert_sample_array(x)
66
66
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
67
+
67
68
  # initialize some variables.
68
69
  distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
69
70
  init_cluster_centers(distance_mat)
@@ -76,6 +77,7 @@ module Rumale
76
77
  end
77
78
  new_error = distance_mat[true, @medoid_ids].mean
78
79
  break if (error - new_error).abs <= @params[:tol]
80
+
79
81
  error = new_error
80
82
  end
81
83
  @cluster_centers = x[@medoid_ids, true].dup if @params[:metric] == 'euclidean'
@@ -93,6 +95,7 @@ module Rumale
93
95
  if @params[:metric] == 'precomputed' && distance_mat.shape[1] != @medoid_ids.size
94
96
  raise ArgumentError, 'Expect the size input matrix to be n_samples-by-n_clusters.'
95
97
  end
98
+
96
99
  assign_cluster(distance_mat)
97
100
  end
98
101
 
@@ -123,6 +126,7 @@ module Rumale
123
126
  sub_rng = @rng.dup
124
127
  @medoid_ids = Numo::Int32.asarray([*0...n_samples].sample(@params[:n_clusters], random: sub_rng))
125
128
  return unless @params[:init] == 'k-means++'
129
+
126
130
  # k-means++ initialize
127
131
  (1...@params[:n_clusters]).each do |n|
128
132
  distances = distance_mat[true, @medoid_ids[0...n]]
@@ -0,0 +1,139 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
+ module Rumale
8
+ module Clustering
9
+ # MiniBatchKMeans is a class that implements K-Means cluster analysis
10
+ # with mini-batch stochastic gradient descent (SGD).
11
+ #
12
+ # @example
13
+ # analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
14
+ # cluster_labels = analyzer.fit_predict(samples)
15
+ #
16
+ # *Reference*
17
+ # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
18
+ class MiniBatchKMeans
19
+ include Base::BaseEstimator
20
+ include Base::ClusterAnalyzer
21
+
22
+ # Return the centroids.
23
+ # @return [Numo::DFloat] (shape: [n_clusters, n_features])
24
+ attr_reader :cluster_centers
25
+
26
+ # Return the random generator.
27
+ # @return [Random]
28
+ attr_reader :rng
29
+
30
+ # Create a new cluster analyzer with K-Means method with mini-batch SGD.
31
+ #
32
+ # @param n_clusters [Integer] The number of clusters.
33
+ # @param init [String] The initialization method for centroids ('random' or 'k-means++').
34
+ # @param max_iter [Integer] The maximum number of iterations.
35
+ # @param batch_size [Integer] The size of the mini batches.
36
+ # @param tol [Float] The tolerance of termination criterion.
37
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
38
+ def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
39
+ check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
40
+ check_params_string(init: init)
41
+ check_params_numeric_or_nil(random_seed: random_seed)
42
+ check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
43
+ @params = {}
44
+ @params[:n_clusters] = n_clusters
45
+ @params[:init] = init == 'random' ? 'random' : 'k-means++'
46
+ @params[:max_iter] = max_iter
47
+ @params[:batch_size] = batch_size
48
+ @params[:tol] = tol
49
+ @params[:random_seed] = random_seed
50
+ @params[:random_seed] ||= srand
51
+ @cluster_centers = nil
52
+ @rng = Random.new(@params[:random_seed])
53
+ end
54
+
55
+ # Analyze clusters with the given training data.
56
+ #
57
+ # @overload fit(x) -> MiniBatchKMeans
58
+ #
59
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
60
+ # @return [MiniBatchKMeans] The learned cluster analyzer itself.
61
+ def fit(x, _y = nil)
62
+ x = check_convert_sample_array(x)
63
+ # initialization.
64
+ n_samples = x.shape[0]
65
+ update_counter = Numo::Int32.zeros(@params[:n_clusters])
66
+ sub_rng = @rng.dup
67
+ init_cluster_centers(x, sub_rng)
68
+ # optimization with mini-batch sgd.
69
+ @params[:max_iter].times do |_t|
70
+ sample_ids = [*0...n_samples].shuffle(random: sub_rng)
71
+ old_centers = @cluster_centers.dup
72
+ until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
73
+ # sub sampling
74
+ sub_x = x[subset_ids, true]
75
+ # assign nearest centroids
76
+ cluster_labels = assign_cluster(sub_x)
77
+ # update centroids
78
+ @params[:n_clusters].times do |c|
79
+ assigned_bits = cluster_labels.eq(c)
80
+ next unless assigned_bits.count.positive?
81
+
82
+ update_counter[c] += 1
83
+ learning_rate = 1.fdiv(update_counter[c])
84
+ update = sub_x[assigned_bits.where, true].mean(axis: 0)
85
+ @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
86
+ end
87
+ end
88
+ error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
89
+ break if error <= @params[:tol]
90
+ end
91
+ self
92
+ end
93
+
94
+ # Predict cluster labels for samples.
95
+ #
96
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
97
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
98
+ def predict(x)
99
+ x = check_convert_sample_array(x)
100
+ assign_cluster(x)
101
+ end
102
+
103
+ # Analyze clusters and assign samples to clusters.
104
+ #
105
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
106
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
107
+ def fit_predict(x)
108
+ x = check_convert_sample_array(x)
109
+ fit(x)
110
+ predict(x)
111
+ end
112
+
113
+ private
114
+
115
+ def assign_cluster(x)
116
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
117
+ distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
118
+ end
119
+
120
+ def init_cluster_centers(x, sub_rng)
121
+ # random initialize
122
+ n_samples = x.shape[0]
123
+ rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
124
+ @cluster_centers = x[rand_id, true].dup
125
+ return unless @params[:init] == 'k-means++'
126
+
127
+ # k-means++ initialize
128
+ (1...@params[:n_clusters]).each do |n|
129
+ distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
130
+ min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
131
+ probs = min_distances**2 / (min_distances**2).sum
132
+ cum_probs = probs.cumsum
133
+ selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
134
+ @cluster_centers[n, true] = x[selected_id, true].dup
135
+ end
136
+ end
137
+ end
138
+ end
139
+ end