rumale 0.18.7 → 0.20.0

Files changed (42)
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +66 -1
  3. data/CHANGELOG.md +46 -0
  4. data/Gemfile +2 -0
  5. data/README.md +5 -36
  6. data/lib/rumale.rb +5 -10
  7. data/lib/rumale/clustering/hdbscan.rb +1 -1
  8. data/lib/rumale/clustering/k_means.rb +1 -1
  9. data/lib/rumale/clustering/k_medoids.rb +1 -1
  10. data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
  11. data/lib/rumale/dataset.rb +3 -3
  12. data/lib/rumale/decomposition/pca.rb +23 -5
  13. data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
  14. data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
  15. data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
  16. data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
  17. data/lib/rumale/linear_model/base_sgd.rb +1 -1
  18. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -1
  19. data/lib/rumale/model_selection/cross_validation.rb +3 -2
  20. data/lib/rumale/model_selection/k_fold.rb +1 -1
  21. data/lib/rumale/model_selection/shuffle_split.rb +1 -1
  22. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +2 -2
  23. data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
  24. data/lib/rumale/neural_network/adam.rb +1 -1
  25. data/lib/rumale/neural_network/base_mlp.rb +1 -1
  26. data/lib/rumale/preprocessing/binarizer.rb +60 -0
  27. data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
  28. data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
  29. data/lib/rumale/preprocessing/max_normalizer.rb +62 -0
  30. data/lib/rumale/version.rb +1 -1
  31. data/rumale.gemspec +1 -3
  32. metadata +11 -44
  33. data/lib/rumale/linear_model/base_linear_model.rb +0 -101
  34. data/lib/rumale/optimizer/ada_grad.rb +0 -39
  35. data/lib/rumale/optimizer/adam.rb +0 -53
  36. data/lib/rumale/optimizer/nadam.rb +0 -62
  37. data/lib/rumale/optimizer/rmsprop.rb +0 -47
  38. data/lib/rumale/optimizer/sgd.rb +0 -43
  39. data/lib/rumale/optimizer/yellow_fin.rb +0 -101
  40. data/lib/rumale/polynomial_model/base_factorization_machine.rb +0 -121
  41. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +0 -215
  42. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +0 -129
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 5e3069531e5acbdaab178769d20684a0fa260e29f6c39a645632b903fff8cce0
-   data.tar.gz: d39c7e61a20b1bce23ccbb9d809bb06f1babce101f509c78c6da2a64e95f180f
+   metadata.gz: 358515f8785eb3de2e6571a957ca76cece6b774bb022c1a0951c92d44ab422b4
+   data.tar.gz: '0289b7eb382cd3300845412af0fd43626f4f827bb719083c879b574e3ab37eb0'
  SHA512:
-   metadata.gz: eb9077b26d63f153eefd4c68ea57083e12b6a465d06864da8b24a3f9d2aff907b8de350ade5c96e8e9ec28997424839b91ac884d787a7bff7c2a44d212addd81
-   data.tar.gz: 7d94c6d80e16ed405f87a7c777b4922863e42922a5b33046df4bc42d9daa5f5243ebb6c5492cb2d20120215cdce9a5c4b0ac3156012bf38cf91e7717b2c51c22
+   metadata.gz: f03fc0f27f99ed4acea3fb7d7bf34017c1dbf923b20dabc9a78d6d44f0b151bc9dc78ba24d122f81607a43fd1852e398a603b75b87656a2f79109f87c0db0d98
+   data.tar.gz: 69f6b8892f6bfb4c43706513245c3fba687dcb6a347c1c5185a70d5e45a024b2848a019bfae48726e1f49212878e8d6d67c811ec5f4a990fdbb3a2841efdfe9b
data/.rubocop.yml CHANGED
@@ -24,6 +24,15 @@ Style/HashTransformKeys:
  Style/HashTransformValues:
    Enabled: true

+ Lint/DeprecatedOpenSSLConstant:
+   Enabled: true
+
+ Lint/DuplicateElsifCondition:
+   Enabled: true
+
+ Lint/MixedRegexpCaptureTypes:
+   Enabled: true
+
  Lint/RaiseException:
    Enabled: true

@@ -34,7 +43,6 @@ Layout/LineLength:
    Max: 145
    IgnoredPatterns: ['(\A|\s)#']

-
  Metrics/ModuleLength:
    Max: 200

@@ -70,15 +78,48 @@ Naming/MethodParameterName:
  Naming/ConstantName:
    Enabled: false

+ Style/AccessorGrouping:
+   Enabled: true
+
+ Style/ArrayCoercion:
+   Enabled: true
+
+ Style/BisectedAttrAccessor:
+   Enabled: true
+
+ Style/CaseLikeIf:
+   Enabled: true
+
  Style/ExponentialNotation:
    Enabled: true

  Style/FormatStringToken:
    Enabled: false

+ Style/HashAsLastArrayItem:
+   Enabled: true
+
+ Style/HashLikeCase:
+   Enabled: true
+
  Style/NumericLiterals:
    Enabled: false

+ Style/RedundantAssignment:
+   Enabled: true
+
+ Style/RedundantFetchBlock:
+   Enabled: true
+
+ Style/RedundantFileExtensionInRequire:
+   Enabled: true
+
+ Style/RedundantRegexpCharacterClass:
+   Enabled: true
+
+ Style/RedundantRegexpEscape:
+   Enabled: true
+
  Style/SlicingWithRange:
    Enabled: true

@@ -91,6 +132,30 @@ Layout/EmptyLinesAroundAttributeAccessor:
  Layout/SpaceAroundMethodCallOperator:
    Enabled: true

+ Performance/AncestorsInclude:
+   Enabled: true
+
+ Performance/BigDecimalWithNumericArgument:
+   Enabled: true
+
+ Performance/RedundantSortBlock:
+   Enabled: true
+
+ Performance/RedundantStringChars:
+   Enabled: true
+
+ Performance/ReverseFirst:
+   Enabled: true
+
+ Performance/SortReverse:
+   Enabled: true
+
+ Performance/Squeeze:
+   Enabled: true
+
+ Performance/StringInclude:
+   Enabled: true
+
  RSpec/MultipleExpectations:
    Enabled: false

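Most of the newly enabled cops are mechanical rewrites. As a rough illustration of two of the Performance cops switched on above (generic snippets based on the cops' documented behaviour, not code taken from this diff):

```ruby
# Performance/Squeeze: collapse runs of a single character with String#squeeze.
'ruby    gems'.gsub(/ +/, ' ')     # flagged form
'ruby    gems'.squeeze(' ')        # preferred form; both return "ruby gems"

# Performance/StringInclude: prefer String#include? over a literal-only regexp match.
'rumale 0.20.0'.match?(/rumale/)   # flagged form
'rumale 0.20.0'.include?('rumale') # preferred form; both return true
```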
data/CHANGELOG.md CHANGED
@@ -1,3 +1,49 @@
+ # 0.20.0
+ ## Breaking changes
+ - Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
+
+ # 0.19.3
+ - Add a preprocessing class for [Binarizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/Binarizer.html).
+ - Add a preprocessing class for [MaxNormalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/MaxNormalizer.html).
+ - Refactor some code with RuboCop.
+
+ # 0.19.2
+ - Fix L2Normalizer to avoid division by zero.
+ - Add a preprocessing class for [L1Normalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/L1Normalizer.html).
+ - Add a transformer class for [TfidfTransformer](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/TfidfTransformer.html).
+
+ # 0.19.1
+ - Add a cluster analysis class for [mini-batch K-Means](https://yoshoku.github.io/rumale/doc/Rumale/Clustering/MiniBatchKMeans.html).
+ - Fix some typos.
+
+ # 0.19.0
+ ## Breaking changes
+ - Change the mmh3 and mopti gems to non-runtime dependencies.
+   - The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
+     You only need to require the mmh3 gem when using FeatureHasher.
+     ```ruby
+     require 'mmh3'
+     require 'rumale'
+
+     encoder = Rumale::FeatureExtraction::FeatureHasher.new
+     ```
+   - The mopti gem is used in [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
+     You only need to require the mopti gem when using NeighbourhoodComponentAnalysis.
+     ```ruby
+     require 'mopti'
+     require 'rumale'
+
+     transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
+     ```
+ - Change the default value of the solver parameter on [PCA](https://yoshoku.github.io/rumale/doc/Rumale/Decomposition/PCA.html) to 'auto'.
+   If Numo::Linalg is loaded, 'evd' is selected for the solver; otherwise, 'fpt' is selected.
+ - Deprecate [PolynomialModel](https://yoshoku.github.io/rumale/doc/Rumale/PolynomialModel.html), [Optimizer](https://yoshoku.github.io/rumale/doc/Rumale/Optimizer.html), and the estimators contained in them. They will be deleted in version 0.20.0.
+   - Many machine learning libraries do not include factorization machine algorithms; they are instead provided by separate compatible libraries.
+     In addition, there are no plans to implement further estimators in PolynomialModel.
+     Thus, the author decided to deprecate PolynomialModel.
+   - Currently, the Optimizer classes are only used by the PolynomialModel estimators.
+     Therefore, they have been deprecated together with PolynomialModel.
+
  # 0.18.7
  - Fix to convert target_name to a string array in the [classification_report method](https://yoshoku.github.io/rumale/doc/Rumale/EvaluationMeasure.html#classification_report-class_method).
  - Refactor some code with RuboCop.
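The 0.19.x entries above introduce several small transformers. A minimal usage sketch, with class and method names taken from the linked API docs; the `threshold:` value is just an arbitrary example, and the defaults of the other constructors are assumed:

```ruby
require 'rumale'

# Toy term-frequency matrix: 3 documents x 4 terms.
tf = Numo::DFloat[[1, 0, 2, 0], [0, 1, 1, 0], [3, 0, 0, 1]]

# The new classes follow the usual fit/transform (fit_transform) interface.
tfidf  = Rumale::FeatureExtraction::TfidfTransformer.new.fit_transform(tf)
l1     = Rumale::Preprocessing::L1Normalizer.new.fit_transform(tf)
maxed  = Rumale::Preprocessing::MaxNormalizer.new.fit_transform(tf)
binary = Rumale::Preprocessing::Binarizer.new(threshold: 0.5).fit_transform(tf)

p tfidf.shape # => [3, 4]
p binary.to_a # 0/1 matrix thresholded at 0.5
```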
data/Gemfile CHANGED
@@ -4,6 +4,8 @@ source 'https://rubygems.org'
  gemspec

  gem 'coveralls', '~> 0.8'
+ gem 'mmh3', '>= 1.0'
+ gem 'mopti', '>= 0.1.0'
  gem 'numo-linalg', '>= 0.1.4'
  gem 'parallel', '>= 1.17.0'
  gem 'rake', '~> 12.0'
data/README.md CHANGED
@@ -11,7 +11,7 @@
  Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
  Rumale provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
  Rumale supports Support Vector Machine,
- Logistic Regression, Ridge, Lasso, Factorization Machine,
+ Logistic Regression, Ridge, Lasso,
  Multi-layer Perceptron,
  Naive Bayes, Decision Tree, Gradient Tree Boosting, Random Forest,
  K-Means, Gaussian Mixture Model, DBSCAN, Spectral Clustering,
@@ -42,39 +42,7 @@ Or install it yourself as:

  ## Usage

- ### Example 1. XOR data
- First, let's classify simple xor data.
-
- ```ruby
- require 'rumale'
-
- # Prepare XOR data.
- samples = [[0, 0], [0, 1], [1, 0], [1, 1]]
- labels = [0, 1, 1, 0]
-
- # Train classifier with nearest neighbor rule.
- estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1)
- estimator.fit(samples, labels)
-
- # Predict labels.
- p labels
- p estimator.predict(samples)
- ```
-
- Execution of the above script result in the following.
-
- ```ruby
- [0, 1, 1, 0]
- Numo::Int32#shape=[4]
- [0, 1, 1, 0]
- ```
-
- The basic usage of Rumale is to first train the model with the fit method
- and then estimate with the predict method.
- In addition, Rumale recommends using arrays such as feature vectors and labels with
- [Numo::NArray](https://github.com/ruby-numo/numo-narray).
-
- ### Example 2. Pendigits dataset classification
+ ### Example 1. Pendigits dataset classification

  Rumale provides a function for loading libsvm format dataset files.
  We start by downloading the pendigits dataset from the LIBSVM Data web site.
@@ -137,7 +105,7 @@ $ ruby test.rb
  Accuracy: 98.7%
  ```

- ### Example 3. Cross-validation
+ ### Example 2. Cross-validation

  ```ruby
  require 'rumale'
@@ -168,7 +136,7 @@ $ ruby cross_validation.rb
  5-CV mean log-loss: 0.355
  ```

- ### Example 4. Pipeline
+ ### Example 3. Pipeline

  ```ruby
  require 'rumale'
@@ -203,6 +171,7 @@ $ ruby pipeline.rb
  ## Speed up

  ### Numo::Linalg
+ Rumale uses [Numo::NArray](https://github.com/ruby-numo/numo-narray) for typed arrays.
  Loading [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) allows matrix products of Numo::NArray to be performed with BLAS libraries.
  For example, using [OpenBLAS](https://github.com/xianyi/OpenBLAS) speeds up many estimators in Rumale.

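The Numo::Linalg note above is also the hook for the new PCA `solver: 'auto'` behaviour described in the 0.19.0 changelog entry. A hedged sketch of the typical setup (the `numo/linalg/autoloader` require is the usual Numo::Linalg entry point; the Gemfile line is an assumption about your project):

```ruby
# Gemfile (assumed): gem 'numo-linalg'
require 'numo/linalg/autoloader' # picks up an installed BLAS/LAPACK such as OpenBLAS
require 'rumale'

# With Numo::Linalg loaded, PCA's solver: 'auto' resolves to 'evd';
# without it, Rumale falls back to the fixed-point solver 'fpt'.
decomposer = Rumale::Decomposition::PCA.new(n_components: 2)
```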
data/lib/rumale.rb CHANGED
@@ -18,17 +18,10 @@ require 'rumale/base/cluster_analyzer'
  require 'rumale/base/transformer'
  require 'rumale/base/splitter'
  require 'rumale/base/evaluator'
- require 'rumale/optimizer/sgd'
- require 'rumale/optimizer/ada_grad'
- require 'rumale/optimizer/rmsprop'
- require 'rumale/optimizer/adam'
- require 'rumale/optimizer/nadam'
- require 'rumale/optimizer/yellow_fin'
  require 'rumale/pipeline/pipeline'
  require 'rumale/pipeline/feature_union'
  require 'rumale/kernel_approximation/rbf'
  require 'rumale/kernel_approximation/nystroem'
- require 'rumale/linear_model/base_linear_model'
  require 'rumale/linear_model/base_sgd'
  require 'rumale/linear_model/svc'
  require 'rumale/linear_model/svr'
@@ -41,9 +34,6 @@ require 'rumale/kernel_machine/kernel_svc'
  require 'rumale/kernel_machine/kernel_pca'
  require 'rumale/kernel_machine/kernel_fda'
  require 'rumale/kernel_machine/kernel_ridge'
- require 'rumale/polynomial_model/base_factorization_machine'
- require 'rumale/polynomial_model/factorization_machine_classifier'
- require 'rumale/polynomial_model/factorization_machine_regressor'
  require 'rumale/multiclass/one_vs_rest_classifier'
  require 'rumale/nearest_neighbors/vp_tree'
  require 'rumale/nearest_neighbors/k_neighbors_classifier'
@@ -70,6 +60,7 @@ require 'rumale/ensemble/random_forest_regressor'
  require 'rumale/ensemble/extra_trees_classifier'
  require 'rumale/ensemble/extra_trees_regressor'
  require 'rumale/clustering/k_means'
+ require 'rumale/clustering/mini_batch_k_means'
  require 'rumale/clustering/k_medoids'
  require 'rumale/clustering/gaussian_mixture'
  require 'rumale/clustering/dbscan'
@@ -92,7 +83,10 @@ require 'rumale/neural_network/mlp_regressor'
  require 'rumale/neural_network/mlp_classifier'
  require 'rumale/feature_extraction/hash_vectorizer'
  require 'rumale/feature_extraction/feature_hasher'
+ require 'rumale/feature_extraction/tfidf_transformer'
  require 'rumale/preprocessing/l2_normalizer'
+ require 'rumale/preprocessing/l1_normalizer'
+ require 'rumale/preprocessing/max_normalizer'
  require 'rumale/preprocessing/min_max_scaler'
  require 'rumale/preprocessing/max_abs_scaler'
  require 'rumale/preprocessing/standard_scaler'
@@ -101,6 +95,7 @@ require 'rumale/preprocessing/label_binarizer'
  require 'rumale/preprocessing/label_encoder'
  require 'rumale/preprocessing/one_hot_encoder'
  require 'rumale/preprocessing/ordinal_encoder'
+ require 'rumale/preprocessing/binarizer'
  require 'rumale/preprocessing/polynomial_features'
  require 'rumale/model_selection/k_fold'
  require 'rumale/model_selection/stratified_k_fold'
data/lib/rumale/clustering/hdbscan.rb CHANGED
@@ -232,7 +232,7 @@ module Rumale
    end

    def flatten(tree, stabilities)
-     node_ids = stabilities.keys.sort { |a, b| b <=> a }.slice(0, stabilities.size - 1)
+     node_ids = stabilities.keys.sort.reverse.slice(0, stabilities.size - 1)

      cluster_tree = tree.select { |edge| edge.n_elements > 1 }
      is_cluster = node_ids.each_with_object({}) { |n_id, h| h[n_id] = true }
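This one-line change matches the Performance/SortReverse cop enabled in the .rubocop.yml hunk above; the two forms produce the same descending order, for example:

```ruby
stabilities = { 3 => 0.1, 1 => 0.5, 2 => 0.3 }
stabilities.keys.sort { |a, b| b <=> a } # => [3, 2, 1]
stabilities.keys.sort.reverse            # => [3, 2, 1] (form preferred by the cop)
```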
data/lib/rumale/clustering/k_means.rb CHANGED
@@ -103,7 +103,7 @@ module Rumale
      # random initialize
      n_samples = x.shape[0]
      sub_rng = @rng.dup
-     rand_id = [*0...n_samples].sample(@params[:n_clusters], random: sub_rng)
+     rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
      @cluster_centers = x[rand_id, true].dup
      return unless @params[:init] == 'k-means++'

data/lib/rumale/clustering/k_medoids.rb CHANGED
@@ -124,7 +124,7 @@ module Rumale
      # random initialize
      n_samples = distance_mat.shape[0]
      sub_rng = @rng.dup
-     @medoid_ids = Numo::Int32.asarray([*0...n_samples].sample(@params[:n_clusters], random: sub_rng))
+     @medoid_ids = Numo::Int32.asarray(Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng))
      return unless @params[:init] == 'k-means++'

      # k-means++ initialize
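The k_means.rb and k_medoids.rb hunks are the same behaviour-preserving refactor, likely prompted by the RuboCop update above: coercing a range with `Array()` instead of splatting it into a literal. Both spellings build the identical index array:

```ruby
n_samples = 5
[*0...n_samples]     # => [0, 1, 2, 3, 4]
Array(0...n_samples) # => [0, 1, 2, 3, 4]
```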
data/lib/rumale/clustering/mini_batch_k_means.rb ADDED
@@ -0,0 +1,139 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/cluster_analyzer'
+ require 'rumale/pairwise_metric'
+
+ module Rumale
+   module Clustering
+     # MiniBatchKMeans is a class that implements K-Means cluster analysis
+     # with mini-batch stochastic gradient descent (SGD).
+     #
+     # @example
+     #   analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
+     #   cluster_labels = analyzer.fit_predict(samples)
+     #
+     # *Reference*
+     # - Sculley, D., "Web-scale k-means clustering," Proc. WWW'10, pp. 1177--1178, 2010.
+     class MiniBatchKMeans
+       include Base::BaseEstimator
+       include Base::ClusterAnalyzer
+
+       # Return the centroids.
+       # @return [Numo::DFloat] (shape: [n_clusters, n_features])
+       attr_reader :cluster_centers
+
+       # Return the random generator.
+       # @return [Random]
+       attr_reader :rng
+
+       # Create a new cluster analyzer with the K-Means method using mini-batch SGD.
+       #
+       # @param n_clusters [Integer] The number of clusters.
+       # @param init [String] The initialization method for centroids ('random' or 'k-means++').
+       # @param max_iter [Integer] The maximum number of iterations.
+       # @param batch_size [Integer] The size of the mini-batches.
+       # @param tol [Float] The tolerance of the termination criterion.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
+         check_params_numeric(n_clusters: n_clusters, max_iter: max_iter, batch_size: batch_size, tol: tol)
+         check_params_string(init: init)
+         check_params_numeric_or_nil(random_seed: random_seed)
+         check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
+         @params = {}
+         @params[:n_clusters] = n_clusters
+         @params[:init] = init == 'random' ? 'random' : 'k-means++'
+         @params[:max_iter] = max_iter
+         @params[:batch_size] = batch_size
+         @params[:tol] = tol
+         @params[:random_seed] = random_seed
+         @params[:random_seed] ||= srand
+         @cluster_centers = nil
+         @rng = Random.new(@params[:random_seed])
+       end
+
+       # Analyze clusters with given training data.
+       #
+       # @overload fit(x) -> MiniBatchKMeans
+       #
+       #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+       #   @return [MiniBatchKMeans] The learned cluster analyzer itself.
+       def fit(x, _y = nil)
+         x = check_convert_sample_array(x)
+         # initialization.
+         n_samples = x.shape[0]
+         update_counter = Numo::Int32.zeros(@params[:n_clusters])
+         sub_rng = @rng.dup
+         init_cluster_centers(x, sub_rng)
+         # optimization with mini-batch sgd.
+         @params[:max_iter].times do |_t|
+           sample_ids = Array(0...n_samples).shuffle(random: sub_rng)
+           old_centers = @cluster_centers.dup
+           until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
+             # sub sampling
+             sub_x = x[subset_ids, true]
+             # assign nearest centroids
+             cluster_labels = assign_cluster(sub_x)
+             # update centroids
+             @params[:n_clusters].times do |c|
+               assigned_bits = cluster_labels.eq(c)
+               next unless assigned_bits.count.positive?
+
+               update_counter[c] += 1
+               learning_rate = 1.fdiv(update_counter[c])
+               update = sub_x[assigned_bits.where, true].mean(axis: 0)
+               @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
+             end
+           end
+           error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
+           break if error <= @params[:tol]
+         end
+         self
+       end
+
+       # Predict cluster labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+       def predict(x)
+         x = check_convert_sample_array(x)
+         assign_cluster(x)
+       end
+
+       # Analyze clusters and assign samples to clusters.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
+       def fit_predict(x)
+         x = check_convert_sample_array(x)
+         fit(x)
+         predict(x)
+       end
+
+       private
+
+       def assign_cluster(x)
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
+         distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
+       end
+
+       def init_cluster_centers(x, sub_rng)
+         # random initialize
+         n_samples = x.shape[0]
+         rand_id = Array(0...n_samples).sample(@params[:n_clusters], random: sub_rng)
+         @cluster_centers = x[rand_id, true].dup
+         return unless @params[:init] == 'k-means++'
+
+         # k-means++ initialize
+         (1...@params[:n_clusters]).each do |n|
+           distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
+           min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
+           probs = min_distances**2 / (min_distances**2).sum
+           cum_probs = probs.cumsum
+           selected_id = cum_probs.gt(sub_rng.rand).where.to_a.first
+           @cluster_centers[n, true] = x[selected_id, true].dup
+         end
+       end
+     end
+   end
+ end
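The centroid update in `fit` above is Sculley's per-centre rule: each centroid moves toward the mean of its assigned mini-batch samples with step `1 / update_counter[c]`, so later batches move it less and the centroid converges to a running average. A toy one-dimensional trace of that rule (plain Ruby, for illustration only, not part of the gem):

```ruby
# c <- (1 - eta) * c + eta * batch_mean, with eta = 1 / number_of_updates so far
center = 0.0
update_count = 0
[4.0, 2.0, 6.0].each do |batch_mean| # means of three successive mini-batches
  update_count += 1
  eta = 1.0 / update_count
  center = (1 - eta) * center + eta * batch_mean
end
puts center # => 4.0, the running average of the batch means
```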