rumale 0.18.6 → 0.19.3

Files changed (74)
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +80 -3
  3. data/CHANGELOG.md +45 -0
  4. data/Gemfile +2 -0
  5. data/README.md +5 -36
  6. data/lib/rumale.rb +5 -0
  7. data/lib/rumale/base/base_estimator.rb +2 -0
  8. data/lib/rumale/clustering/dbscan.rb +4 -0
  9. data/lib/rumale/clustering/gaussian_mixture.rb +2 -0
  10. data/lib/rumale/clustering/hdbscan.rb +3 -1
  11. data/lib/rumale/clustering/k_means.rb +2 -1
  12. data/lib/rumale/clustering/k_medoids.rb +5 -1
  13. data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
  14. data/lib/rumale/clustering/power_iteration.rb +2 -0
  15. data/lib/rumale/clustering/single_linkage.rb +2 -0
  16. data/lib/rumale/dataset.rb +5 -3
  17. data/lib/rumale/decomposition/factor_analysis.rb +2 -0
  18. data/lib/rumale/decomposition/pca.rb +24 -5
  19. data/lib/rumale/ensemble/ada_boost_classifier.rb +3 -0
  20. data/lib/rumale/ensemble/ada_boost_regressor.rb +3 -0
  21. data/lib/rumale/evaluation_measure/function.rb +2 -1
  22. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +2 -0
  23. data/lib/rumale/evaluation_measure/precision_recall.rb +5 -0
  24. data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
  25. data/lib/rumale/evaluation_measure/silhouette_score.rb +2 -0
  26. data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
  27. data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -0
  28. data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
  29. data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
  30. data/lib/rumale/kernel_machine/kernel_ridge.rb +2 -0
  31. data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
  32. data/lib/rumale/linear_model/base_linear_model.rb +3 -1
  33. data/lib/rumale/linear_model/base_sgd.rb +1 -1
  34. data/lib/rumale/linear_model/linear_regression.rb +1 -0
  35. data/lib/rumale/linear_model/ridge.rb +1 -0
  36. data/lib/rumale/manifold/mds.rb +2 -0
  37. data/lib/rumale/manifold/tsne.rb +4 -0
  38. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +14 -1
  39. data/lib/rumale/model_selection/cross_validation.rb +3 -2
  40. data/lib/rumale/model_selection/grid_search_cv.rb +1 -0
  41. data/lib/rumale/model_selection/k_fold.rb +1 -1
  42. data/lib/rumale/model_selection/shuffle_split.rb +1 -1
  43. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +2 -2
  44. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +1 -0
  45. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +2 -0
  46. data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
  47. data/lib/rumale/neural_network/adam.rb +1 -1
  48. data/lib/rumale/neural_network/base_mlp.rb +2 -1
  49. data/lib/rumale/optimizer/ada_grad.rb +3 -0
  50. data/lib/rumale/optimizer/adam.rb +3 -0
  51. data/lib/rumale/optimizer/nadam.rb +5 -0
  52. data/lib/rumale/optimizer/rmsprop.rb +3 -0
  53. data/lib/rumale/optimizer/sgd.rb +3 -0
  54. data/lib/rumale/optimizer/yellow_fin.rb +3 -0
  55. data/lib/rumale/pipeline/pipeline.rb +3 -0
  56. data/lib/rumale/polynomial_model/base_factorization_machine.rb +6 -1
  57. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +5 -0
  58. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +5 -0
  59. data/lib/rumale/preprocessing/binarizer.rb +60 -0
  60. data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
  61. data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
  62. data/lib/rumale/preprocessing/max_normalizer.rb +62 -0
  63. data/lib/rumale/preprocessing/one_hot_encoder.rb +3 -0
  64. data/lib/rumale/preprocessing/ordinal_encoder.rb +2 -0
  65. data/lib/rumale/preprocessing/polynomial_features.rb +1 -0
  66. data/lib/rumale/probabilistic_output.rb +2 -0
  67. data/lib/rumale/tree/base_decision_tree.rb +2 -0
  68. data/lib/rumale/tree/decision_tree_classifier.rb +1 -0
  69. data/lib/rumale/tree/gradient_tree_regressor.rb +1 -0
  70. data/lib/rumale/utils.rb +1 -0
  71. data/lib/rumale/validation.rb +7 -0
  72. data/lib/rumale/version.rb +1 -1
  73. data/rumale.gemspec +1 -3
  74. metadata +11 -34
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 40d5504cf4463721f53a4202ed99ec3f015c571fbadf3a4a4c7e0ac6eb00c7a7
- data.tar.gz: fae3bebad1e88aa166d9279e5f5a2de4ebbad5f79fd416fef68d33d4f66ba2c6
+ metadata.gz: dc3413c05ad7c365117adc4abbc304ff1851fa9a3ff69fef3c69e730d9a2b834
+ data.tar.gz: 8895cce8b350c4e245aabb5e3e4c4036655fa8d24a72f6481e0d2f8c9869fa54
  SHA512:
- metadata.gz: b9d32bc9bd5c5f37d27b06fcaa554c28f9a209debaaac4024c1c2a1f6fb367484ce760168f62a2d9e1ee24d9372ad9cccd1d36e7280f202734e5330105a995fa
- data.tar.gz: c18470cb533df4f6315324942afc98b5c52f4b7f6246078f459987a3407b79ae60a42599f40bc6236d5adba3dc85799a091e0d7ae5e9a1a3fd9fc626206cbef2
+ metadata.gz: fe10c975f286a4c9ac155d29310d61d1f180cbcc909ec7bdba3925973b6b9857635befc9bf4938cf28a6ef50c8011894b7b15768735f6200c27ce912907e5fb1
+ data.tar.gz: 327cce25145c1ca3f5623f84b4163560bdbee8245009c3d8e1c3318f61dec94b58a7395ebc3538e7b04a9702af08b077e7f426c54f7c5d4fd3b0fcf11c4744cf
data/.rubocop.yml CHANGED
@@ -3,7 +3,7 @@ require:
  - rubocop-rspec

  AllCops:
- TargetRubyVersion: 2.3
+ TargetRubyVersion: 2.5
  DisplayCopNames: true
  DisplayStyleGuide: true
  Exclude:
@@ -24,6 +24,15 @@ Style/HashTransformKeys:
  Style/HashTransformValues:
  Enabled: true

+ Lint/DeprecatedOpenSSLConstant:
+ Enabled: true
+
+ Lint/DuplicateElsifCondition:
+ Enabled: true
+
+ Lint/MixedRegexpCaptureTypes:
+ Enabled: true
+
  Lint/RaiseException:
  Enabled: true

@@ -34,7 +43,6 @@ Layout/LineLength:
  Max: 145
  IgnoredPatterns: ['(\A|\s)#']

-
  Metrics/ModuleLength:
  Max: 200

@@ -70,14 +78,83 @@ Naming/MethodParameterName:
  Naming/ConstantName:
  Enabled: false

+ Style/AccessorGrouping:
+ Enabled: true
+
+ Style/ArrayCoercion:
+ Enabled: true
+
+ Style/BisectedAttrAccessor:
+ Enabled: true
+
+ Style/CaseLikeIf:
+ Enabled: true
+
+ Style/ExponentialNotation:
+ Enabled: true
+
  Style/FormatStringToken:
  Enabled: false

+ Style/HashAsLastArrayItem:
+ Enabled: true
+
+ Style/HashLikeCase:
+ Enabled: true
+
  Style/NumericLiterals:
  Enabled: false

+ Style/RedundantAssignment:
+ Enabled: true
+
+ Style/RedundantFetchBlock:
+ Enabled: true
+
+ Style/RedundantFileExtensionInRequire:
+ Enabled: true
+
+ Style/RedundantRegexpCharacterClass:
+ Enabled: true
+
+ Style/RedundantRegexpEscape:
+ Enabled: true
+
+ Style/SlicingWithRange:
+ Enabled: true
+
  Layout/EmptyLineAfterGuardClause:
- Enabled: false
+ Enabled: true
+
+ Layout/EmptyLinesAroundAttributeAccessor:
+ Enabled: true
+
+ Layout/SpaceAroundMethodCallOperator:
+ Enabled: true
+
+ Performance/AncestorsInclude:
+ Enabled: true
+
+ Performance/BigDecimalWithNumericArgument:
+ Enabled: true
+
+ Performance/RedundantSortBlock:
+ Enabled: true
+
+ Performance/RedundantStringChars:
+ Enabled: true
+
+ Performance/ReverseFirst:
+ Enabled: true
+
+ Performance/SortReverse:
+ Enabled: true
+
+ Performance/Squeeze:
+ Enabled: true
+
+ Performance/StringInclude:
+ Enabled: true

  RSpec/MultipleExpectations:
  Enabled: false
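Two of the newly enabled cops map directly onto source changes further down in this diff. The snippet below is an illustrative before/after sketch, not code taken from the gem: Performance/SortReverse corresponds to the `sort { |a, b| b <=> a }` → `sort.reverse` rewrite in hdbscan.rb, and the `[*range]` → `Array(range)` rewrites in k_means.rb and k_medoids.rb are likely tied to Style/ArrayCoercion.

```ruby
keys = [3, 1, 2]

# Performance/SortReverse: prefer sort.reverse over a descending comparison block.
keys.sort { |a, b| b <=> a } # => [3, 2, 1] (flagged)
keys.sort.reverse            # => [3, 2, 1] (preferred)

# Array coercion style: build an Array from a Range without a splat literal.
[*0...3]     # => [0, 1, 2] (old style)
Array(0...3) # => [0, 1, 2] (new style)
```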
data/CHANGELOG.md CHANGED
@@ -1,3 +1,48 @@
+ # 0.19.3
+ - Add preprocessing class for [Binarizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/Binarizer.html)
+ - Add preprocessing class for [MaxNormalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/MaxNormalizer.html)
+ - Refactor some code with RuboCop.
+
+ # 0.19.2
+ - Fix L2Normalizer to avoid division by zero.
+ - Add preprocessing class for [L1Normalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/L1Normalizer.html).
+ - Add transformer class for [TfidfTransformer](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/TfidfTransformer.html).
+
+ # 0.19.1
+ - Add cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
+ - Fix some typos.
+
+ # 0.19.0
+ - Change the mmh3 and mopti gems to non-runtime dependencies.
+ - The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
+ You only need to require the mmh3 gem when using FeatureHasher.
+ ```ruby
+ require 'mmh3'
+ require 'rumale'
+
+ encoder = Rumale::FeatureExtraction::FeatureHasher.new
+ ```
+ - The mopti gem is used in [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
+ You only need to require the mopti gem when using NeighbourhoodComponentAnalysis.
+ ```ruby
+ require 'mopti'
+ require 'rumale'
+
+ transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
+ ```
+ - Change the default value of the solver parameter on [PCA](https://yoshoku.github.io/rumale/doc/Rumale/Decomposition/PCA.html) to 'auto'.
+ If Numo::Linalg is loaded, 'evd' is selected for the solver; otherwise 'fpt' is selected.
+ - Deprecate [PolynomialModel](https://yoshoku.github.io/rumale/doc/Rumale/PolynomialModel.html), [Optimizer](https://yoshoku.github.io/rumale/doc/Rumale/Optimizer.html), and the estimators contained in them. They will be deleted in version 0.20.0.
+ - Many machine learning libraries do not contain factorization machine algorithms; they are provided by separate compatible libraries.
+ In addition, there are no plans to implement new estimators in PolynomialModel.
+ Thus, the author decided to deprecate PolynomialModel.
+ - Currently, the Optimizer classes are only used by the PolynomialModel estimators.
+ Therefore, they have been deprecated together with PolynomialModel.
+
+ # 0.18.7
+ - Fix to convert target_name to a string array in the [classification_report method](https://yoshoku.github.io/rumale/doc/Rumale/EvaluationMeasure.html#classification_report-class_method).
+ - Refactor some code with RuboCop.
+
  # 0.18.6
  - Fix some configuration files.
  - Update API documentation.
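As a rough usage sketch of the preprocessing classes added in 0.19.2/0.19.3, assuming the `fit_transform` interface described in the linked API docs (the data and threshold below are illustrative only):

```ruby
require 'rumale'

x = Numo::DFloat[[1.0, -2.0, 4.0], [0.0, 3.0, -1.0]]

# Threshold-based feature binarization.
binarized = Rumale::Preprocessing::Binarizer.new(threshold: 0.0).fit_transform(x)

# Row-wise normalization by the L1 norm and by the per-sample maximum.
l1_normalized  = Rumale::Preprocessing::L1Normalizer.new.fit_transform(x)
max_normalized = Rumale::Preprocessing::MaxNormalizer.new.fit_transform(x)
```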
data/Gemfile CHANGED
@@ -4,6 +4,8 @@ source 'https://rubygems.org'
  gemspec

  gem 'coveralls', '~> 0.8'
+ gem 'mmh3', '>= 1.0'
+ gem 'mopti', '>= 0.1.0'
  gem 'numo-linalg', '>= 0.1.4'
  gem 'parallel', '>= 1.17.0'
  gem 'rake', '~> 12.0'
data/README.md CHANGED
@@ -11,7 +11,7 @@
  Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
  Rumale provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
  Rumale supports Support Vector Machine,
- Logistic Regression, Ridge, Lasso, Factorization Machine,
+ Logistic Regression, Ridge, Lasso,
  Multi-layer Perceptron,
  Naive Bayes, Decision Tree, Gradient Tree Boosting, Random Forest,
  K-Means, Gaussian Mixture Model, DBSCAN, Spectral Clustering,
@@ -42,39 +42,7 @@ Or install it yourself as:

  ## Usage

- ### Example 1. XOR data
- First, let's classify simple xor data.
-
- ```ruby
- require 'rumale'
-
- # Prepare XOR data.
- samples = [[0, 0], [0, 1], [1, 0], [1, 1]]
- labels = [0, 1, 1, 0]
-
- # Train classifier with nearest neighbor rule.
- estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1)
- estimator.fit(samples, labels)
-
- # Predict labels.
- p labels
- p estimator.predict(samples)
- ```
-
- Execution of the above script result in the following.
-
- ```ruby
- [0, 1, 1, 0]
- Numo::Int32#shape=[4]
- [0, 1, 1, 0]
- ```
-
- The basic usage of Rumale is to first train the model with the fit method
- and then estimate with the predict method.
- In addition, Rumale recommends using arrays such as feature vectors and labels with
- [Numo::NArray](https://github.com/ruby-numo/numo-narray).
-
- ### Example 2. Pendigits dataset classification
+ ### Example 1. Pendigits dataset classification

  Rumale provides function loading libsvm format dataset file.
  We start by downloading the pendigits dataset from LIBSVM Data web site.
@@ -137,7 +105,7 @@ $ ruby test.rb
  Accuracy: 98.7%
  ```

- ### Example 3. Cross-validation
+ ### Example 2. Cross-validation

  ```ruby
  require 'rumale'
@@ -168,7 +136,7 @@ $ ruby cross_validation.rb
  5-CV mean log-loss: 0.355
  ```

- ### Example 4. Pipeline
+ ### Example 3. Pipeline

  ```ruby
  require 'rumale'
@@ -203,6 +171,7 @@ $ ruby pipeline.rb
  ## Speed up

  ### Numo::Linalg
+ Rumale uses [Numo::NArray](https://github.com/ruby-numo/numo-narray) for typed arrays.
  Loading the [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) allows to perform matrix product of Numo::NArray using BLAS libraries.
  For example, using the [OpenBLAS](https://github.com/xianyi/OpenBLAS) speeds up many estimators in Rumale.
data/lib/rumale.rb CHANGED
@@ -70,6 +70,7 @@ require 'rumale/ensemble/random_forest_regressor'
  require 'rumale/ensemble/extra_trees_classifier'
  require 'rumale/ensemble/extra_trees_regressor'
  require 'rumale/clustering/k_means'
+ require 'rumale/clustering/mini_batch_k_means'
  require 'rumale/clustering/k_medoids'
  require 'rumale/clustering/gaussian_mixture'
  require 'rumale/clustering/dbscan'
@@ -92,7 +93,10 @@ require 'rumale/neural_network/mlp_regressor'
  require 'rumale/neural_network/mlp_classifier'
  require 'rumale/feature_extraction/hash_vectorizer'
  require 'rumale/feature_extraction/feature_hasher'
+ require 'rumale/feature_extraction/tfidf_transformer'
  require 'rumale/preprocessing/l2_normalizer'
+ require 'rumale/preprocessing/l1_normalizer'
+ require 'rumale/preprocessing/max_normalizer'
  require 'rumale/preprocessing/min_max_scaler'
  require 'rumale/preprocessing/max_abs_scaler'
  require 'rumale/preprocessing/standard_scaler'
@@ -101,6 +105,7 @@ require 'rumale/preprocessing/label_binarizer'
  require 'rumale/preprocessing/label_encoder'
  require 'rumale/preprocessing/one_hot_encoder'
  require 'rumale/preprocessing/ordinal_encoder'
+ require 'rumale/preprocessing/binarizer'
  require 'rumale/preprocessing/polynomial_features'
  require 'rumale/model_selection/k_fold'
  require 'rumale/model_selection/stratified_k_fold'
data/lib/rumale/base/base_estimator.rb CHANGED
@@ -25,6 +25,7 @@ module Rumale

  def enable_parallel?
  return false if @params[:n_jobs].nil?
+
  if defined?(Parallel).nil?
  warn('If you want to use parallel option, you should install and load Parallel in advance.')
  return false
@@ -34,6 +35,7 @@ module Rumale

  def n_processes
  return 1 unless enable_parallel?
+
  @params[:n_jobs] <= 0 ? Parallel.processor_count : @params[:n_jobs]
  end

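The guard clauses added above make the parallel-execution rules explicit: `n_jobs` must be set and the parallel gem must be loaded, and a non-positive `n_jobs` resolves to `Parallel.processor_count`. A minimal sketch (RandomForestClassifier is used here only as one estimator that accepts `n_jobs`):

```ruby
require 'parallel' # without this, enable_parallel? warns and falls back to a single process
require 'rumale'

# n_jobs: -1 (any value <= 0) means "use all processors" via Parallel.processor_count.
forest = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 50, n_jobs: -1)
```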
data/lib/rumale/clustering/dbscan.rb CHANGED
@@ -54,6 +54,7 @@ module Rumale
  def fit(x, _y = nil)
  x = check_convert_sample_array(x)
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
  partial_fit(x)
  self
  end
@@ -66,6 +67,7 @@
  def fit_predict(x)
  x = check_convert_sample_array(x)
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
  partial_fit(x)
  labels
  end
@@ -80,6 +82,7 @@
  @labels = Numo::Int32.zeros(n_samples) - 2
  n_samples.times do |query_id|
  next if @labels[query_id] >= -1
+
  cluster_id += 1 if expand_cluster(metric_mat, query_id, cluster_id)
  end
  @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
@@ -102,6 +105,7 @@
  while (m = target_ids.shift)
  neighbor_ids = region_query(metric_mat[m, true])
  next if neighbor_ids.size < @params[:min_samples]
+
  neighbor_ids.each do |n|
  target_ids.push(n) if @labels[n] < -1
  @labels[n] = cluster_id if @labels[n] <= -1
data/lib/rumale/clustering/gaussian_mixture.rb CHANGED
@@ -86,6 +86,7 @@ module Rumale
  new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
  error = (memberships - new_memberships).abs.max
  break if error <= @params[:tol]
+
  memberships = new_memberships.dup
  end
  self
@@ -209,6 +210,7 @@

  def check_enable_linalg(method_name)
  return unless @params[:covariance_type] == 'full' && !enable_linalg?
+
  raise "GaussianMixture##{method_name} requires Numo::Linalg when covariance_type is 'full' but that is not loaded."
  end
  end
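The check_enable_linalg guard above means full-covariance Gaussian mixtures require Numo::Linalg to be loaded up front; a minimal sketch (parameter values are illustrative):

```ruby
require 'numo/linalg/autoloader' # required when covariance_type is 'full'
require 'rumale'

gmm = Rumale::Clustering::GaussianMixture.new(n_clusters: 3, covariance_type: 'full')
```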
data/lib/rumale/clustering/hdbscan.rb CHANGED
@@ -55,6 +55,7 @@ module Rumale
  def fit(x, _y = nil)
  x = check_convert_sample_array(x)
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
  fit_predict(x)
  self
  end
@@ -67,6 +68,7 @@
  def fit_predict(x)
  x = check_convert_sample_array(x)
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
  distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
  @labels = partial_fit(distance_mat)
  end
@@ -230,7 +232,7 @@
  end

  def flatten(tree, stabilities)
- node_ids = stabilities.keys.sort { |a, b| b <=> a }.slice(0, stabilities.size - 1)
+ node_ids = stabilities.keys.sort.reverse.slice(0, stabilities.size - 1)

  cluster_tree = tree.select { |edge| edge.n_elements > 1 }
  is_cluster = node_ids.each_with_object({}) { |n_id, h| h[n_id] = true }
data/lib/rumale/clustering/k_means.rb CHANGED
@@ -103,9 +103,10 @@ module Rumale
  # random initialize
  n_samples = x.shape[0]
  sub_rng = @rng.dup
- rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
+ rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
  @cluster_centers = x[rand_id, true].dup
  return unless @params[:init] == 'k-means++'
+
  # k-means++ initialize
  (1...@params[:n_clusters]).each do |n|
  distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
data/lib/rumale/clustering/k_medoids.rb CHANGED
@@ -64,6 +64,7 @@ module Rumale
  def fit(x, _not_used = nil)
  x = check_convert_sample_array(x)
  raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
  # initialize some variables.
  distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
  init_cluster_centers(distance_mat)
@@ -76,6 +77,7 @@
  end
  new_error = distance_mat[true, @medoid_ids].mean
  break if (error - new_error).abs <= @params[:tol]
+
  error = new_error
  end
  @cluster_centers = x[@medoid_ids, true].dup if @params[:metric] == 'euclidean'
@@ -93,6 +95,7 @@
  if @params[:metric] == 'precomputed' && distance_mat.shape[1] != @medoid_ids.size
  raise ArgumentError, 'Expect the size input matrix to be n_samples-by-n_clusters.'
  end
+
  assign_cluster(distance_mat)
  end

@@ -121,8 +124,9 @@
  # random initialize
  n_samples = distance_mat.shape[0]
  sub_rng = @rng.dup
- @medoid_ids = Numo::Int32.asarray([*0...n_samples].sample(@params[:n_clusters], random: sub_rng))
+ @medoid_ids = Numo::Int32.asarray(Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng))
  return unless @params[:init] == 'k-means++'
+
  # k-means++ initialize
  (1...@params[:n_clusters]).each do |n|
  distances = distance_mat[true, @medoid_ids[0...n]]
data/lib/rumale/clustering/mini_batch_k_means.rb ADDED
@@ -0,0 +1,139 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/cluster_analyzer'
+ require 'rumale/pairwise_metric'
+
+ module Rumale
+   module Clustering
+     # MiniBatchKMeans is a class that implements K-Means cluster analysis
+     # with mini-batch stochastic gradient descent (SGD).
+     #
+     # @example
+     #   analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
+     #   cluster_labels = analyzer.fit_predict(samples)
+     #
+     # *Reference*
+     # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
+     class MiniBatchKMeans
+       include Base::BaseEstimator
+       include Base::ClusterAnalyzer
+
+       # Return the centroids.
+       # @return [Numo::DFloat] (shape: [n_clusters, n_features])
+       attr_reader :cluster_centers
+
+       # Return the random generator.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new cluster analyzer with K-Means method with mini-batch SGD.
+       #
+       # @param n_clusters [Integer] The number of clusters.
+       # @param init [String] The initialization method for centroids ('random' or 'k-means++').
+       # @param max_iter [Integer] The maximum number of iterations.
+       # @param batch_size [Integer] The size of the mini batches.
+       # @param tol [Float] The tolerance of termination criterion.
+       # @param random_seed [Integer] The seed value using to initialize the random generator.
+       def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
+         check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
+         check_params_string(init: init)
+         check_params_numeric_or_nil(random_seed: random_seed)
+         check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
+         @params = {}
+         @params[:n_clusters] = n_clusters
+         @params[:init] = init == 'random' ? 'random' : 'k-means++'
+         @params[:max_iter] = max_iter
+         @params[:batch_size] = batch_size
+         @params[:tol] = tol
+         @params[:random_seed] = random_seed
+         @params[:random_seed] ||= srand
+         @cluster_centers = nil
+         @rng = Random.new(@params[:random_seed])
+       end
+
+       # Analysis clusters with given training data.
+       #
+       # @overload fit(x) -> MiniBatchKMeans
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+       # @return [MiniBatchKMeans] The learned cluster analyzer itself.
+       def fit(x, _y = nil)
+         x = check_convert_sample_array(x)
+         # initialization.
+         n_samples = x.shape[0]
+         update_counter = Numo::Int32.zeros(@params[:n_clusters])
+         sub_rng = @rng.dup
+         init_cluster_centers(x, sub_rng)
+         # optimization with mini-batch sgd.
+         @params[:max_iter].times do |_t|
+           sample_ids = Array(0...n_samples).shuffle(random: sub_rng)
+           old_centers = @cluster_centers.dup
+           until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
+             # sub sampling
+             sub_x = x[subset_ids, true]
+             # assign nearest centroids
+             cluster_labels = assign_cluster(sub_x)
+             # update centroids
+             @params[:n_clusters].times do |c|
+               assigned_bits = cluster_labels.eq(c)
+               next unless assigned_bits.count.positive?
+
+               update_counter[c] += 1
+               learning_rate = 1.fdiv(update_counter[c])
+               update = sub_x[assigned_bits.where, true].mean(axis: 0)
+               @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
+             end
+           end
+           error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
+           break if error <= @params[:tol]
+         end
+         self
+       end
+
+       # Predict cluster labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+       def predict(x)
+         x = check_convert_sample_array(x)
+         assign_cluster(x)
+       end
+
+       # Analysis clusters and assign samples to clusters.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+       def fit_predict(x)
+         x = check_convert_sample_array(x)
+         fit(x)
+         predict(x)
+       end
+
+       private
+
+       def assign_cluster(x)
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
+         distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
+       end
+
+       def init_cluster_centers(x, sub_rng)
+         # random initialize
+         n_samples = x.shape[0]
+         rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
+         @cluster_centers = x[rand_id, true].dup
+         return unless @params[:init] == 'k-means++'
+
+         # k-means++ initialize
+         (1...@params[:n_clusters]).each do |n|
+           distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
+           min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
+           probs = min_distances**2 / (min_distances**2).sum
+           cum_probs = probs.cumsum
+           selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
+           @cluster_centers[n, true] = x[selected_id, true].dup
+         end
+       end
+     end
+   end
+ end
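A short usage sketch for the new MiniBatchKMeans class, following the @example in the file above; the two-blob data here is synthetic and purely illustrative:

```ruby
require 'rumale'

# Two well-separated blobs of 2-D points.
x = (Numo::DFloat.new(100, 2).rand - 2.0).concatenate(Numo::DFloat.new(100, 2).rand + 2.0)

analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 2, max_iter: 50, batch_size: 50, random_seed: 1)
labels = analyzer.fit_predict(x)
p analyzer.cluster_centers.shape # => [2, 2]
```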