rumale 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,14 @@
1
#!/usr/bin/env ruby

# Interactive console for trying out the gem.
# The bundle is loaded before the gem itself is required.
require 'bundler/setup'
require 'rumale'

# Fixtures and other initialization code for a console session can be placed here.
#
# To use Pry in place of IRB, add pry to the Gemfile and enable the two lines below:
# require 'pry'
# Pry.start

require 'irb'
IRB.start(__FILE__)
@@ -0,0 +1,8 @@
1
#!/usr/bin/env bash
# Development setup script: installs the bundle.
# Stop on command failure, on use of unset variables, and on failure inside a pipeline.
set -o errexit -o nounset -o pipefail
IFS=$'\n\t'
set -o verbose -o xtrace

bundle install

# Do any other automated setup that you need to do here
@@ -0,0 +1,70 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require 'rumale/version'
6
+ require 'rumale/validation'
7
+ require 'rumale/values'
8
+ require 'rumale/utils'
9
+ require 'rumale/pairwise_metric'
10
+ require 'rumale/dataset'
11
+ require 'rumale/probabilistic_output'
12
+ require 'rumale/base/base_estimator'
13
+ require 'rumale/base/classifier'
14
+ require 'rumale/base/regressor'
15
+ require 'rumale/base/cluster_analyzer'
16
+ require 'rumale/base/transformer'
17
+ require 'rumale/base/splitter'
18
+ require 'rumale/base/evaluator'
19
+ require 'rumale/optimizer/sgd'
20
+ require 'rumale/optimizer/rmsprop'
21
+ require 'rumale/optimizer/nadam'
22
+ require 'rumale/optimizer/yellow_fin'
23
+ require 'rumale/pipeline/pipeline'
24
+ require 'rumale/kernel_approximation/rbf'
25
+ require 'rumale/linear_model/base_linear_model'
26
+ require 'rumale/linear_model/svc'
27
+ require 'rumale/linear_model/svr'
28
+ require 'rumale/linear_model/logistic_regression'
29
+ require 'rumale/linear_model/linear_regression'
30
+ require 'rumale/linear_model/ridge'
31
+ require 'rumale/linear_model/lasso'
32
+ require 'rumale/kernel_machine/kernel_svc'
33
+ require 'rumale/polynomial_model/base_factorization_machine'
34
+ require 'rumale/polynomial_model/factorization_machine_classifier'
35
+ require 'rumale/polynomial_model/factorization_machine_regressor'
36
+ require 'rumale/multiclass/one_vs_rest_classifier'
37
+ require 'rumale/nearest_neighbors/k_neighbors_classifier'
38
+ require 'rumale/nearest_neighbors/k_neighbors_regressor'
39
+ require 'rumale/naive_bayes/naive_bayes'
40
+ require 'rumale/tree/node'
41
+ require 'rumale/tree/base_decision_tree'
42
+ require 'rumale/tree/decision_tree_classifier'
43
+ require 'rumale/tree/decision_tree_regressor'
44
+ require 'rumale/ensemble/ada_boost_classifier'
45
+ require 'rumale/ensemble/ada_boost_regressor'
46
+ require 'rumale/ensemble/random_forest_classifier'
47
+ require 'rumale/ensemble/random_forest_regressor'
48
+ require 'rumale/clustering/k_means'
49
+ require 'rumale/clustering/dbscan'
50
+ require 'rumale/decomposition/pca'
51
+ require 'rumale/decomposition/nmf'
52
+ require 'rumale/preprocessing/l2_normalizer'
53
+ require 'rumale/preprocessing/min_max_scaler'
54
+ require 'rumale/preprocessing/standard_scaler'
55
+ require 'rumale/preprocessing/label_encoder'
56
+ require 'rumale/preprocessing/one_hot_encoder'
57
+ require 'rumale/model_selection/k_fold'
58
+ require 'rumale/model_selection/stratified_k_fold'
59
+ require 'rumale/model_selection/cross_validation'
60
+ require 'rumale/model_selection/grid_search_cv'
61
+ require 'rumale/evaluation_measure/accuracy'
62
+ require 'rumale/evaluation_measure/precision'
63
+ require 'rumale/evaluation_measure/recall'
64
+ require 'rumale/evaluation_measure/f_score'
65
+ require 'rumale/evaluation_measure/log_loss'
66
+ require 'rumale/evaluation_measure/r2_score'
67
+ require 'rumale/evaluation_measure/mean_squared_error'
68
+ require 'rumale/evaluation_measure/mean_absolute_error'
69
+ require 'rumale/evaluation_measure/purity'
70
+ require 'rumale/evaluation_measure/normalized_mutual_information'
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
module Rumale
  # Namespace for the basic mix-in modules.
  module Base
    # Mix-in shared by all estimators in Rumale.
    # It only exposes the estimator's parameters through a read-only accessor.
    module BaseEstimator
      # Parameters of the estimator, read from the @params instance variable.
      # @return [Hash]
      attr_reader :params
    end
  end
end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/accuracy'
5
+
6
module Rumale
  module Base
    # Module for all classifiers in Rumale.
    #
    # An including class is expected to override #fit and #predict;
    # #score is implemented on top of #predict.
    module Classifier
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored, so that calling the method with data
      # on a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting labels.
      #
      # Any arguments are accepted and ignored (#score calls this method with one argument).
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the mean accuracy of the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Mean accuracy
      def score(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Accuracy.new
        # The true labels are compared with the labels predicted for x.
        evaluator.score(y, predict(x))
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/purity'
5
+
6
module Rumale
  module Base
    # Module for all clustering algorithms in Rumale.
    #
    # An including class is expected to override #fit_predict;
    # #score is implemented on top of it.
    module ClusterAnalyzer
      include Validation

      # An abstract method for analyzing clusters and predicting cluster indices.
      #
      # Any arguments are accepted and ignored (#score calls this method with one argument),
      # so that a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def fit_predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate purity of clustering result.
      #
      # Note that this calls #fit_predict, so the cluster analysis is run on x.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Purity
      def score(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Purity.new
        evaluator.score(y, fit_predict(x))
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all evaluation measures in Rumale.
    module Evaluator
      include Validation

      # An abstract method for evaluation of model.
      #
      # Any arguments are accepted and ignored, so that calling the method with data
      # on a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def score(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/r2_score'
5
+
6
module Rumale
  module Base
    # Module for all regressors in Rumale.
    #
    # An including class is expected to override #fit and #predict;
    # #score is implemented on top of #predict.
    module Regressor
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored, so that calling the method with data
      # on a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting target values.
      #
      # Any arguments are accepted and ignored (#score calls this method with one argument).
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def predict(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the coefficient of determination for the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
      # @return [Float] Coefficient of determination
      def score(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        evaluator = Rumale::EvaluationMeasure::R2Score.new
        # The true target values are compared with the values predicted for x.
        evaluator.score(y, predict(x))
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all validation methods in Rumale.
    module Splitter
      include Validation

      # Return the number of splits.
      # @return [Integer]
      attr_reader :n_splits

      # An abstract method for splitting dataset.
      #
      # Any arguments are accepted and ignored, so that calling the method with data
      # on a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def split(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all transformers in Rumale.
    module Transformer
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored, so that calling the method with data
      # on a class that does not override it raises NotImplementedError
      # instead of an ArgumentError about the number of arguments.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def fit(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for fitting a model and transforming given data.
      #
      # Any arguments are accepted and ignored, for the same reason as in #fit.
      #
      # @raise [NotImplementedError] Always, unless overridden by the including class.
      def fit_transform(*)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,125 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
module Rumale
  module Clustering
    # DBSCAN is a class that implements DBSCAN cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
    class DBSCAN
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the core sample indices, i.e. the indices of the samples that have
      # at least min_samples samples (themselves included) closer than eps.
      # The indices are listed in the order in which the core samples were found.
      # @return [Numo::Int32] (shape: [n_core_samples])
      attr_reader :core_sample_ids

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with DBSCAN method.
      #
      # @param eps [Float] The radius of neighborhood.
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      def initialize(eps: 0.5, min_samples: 5)
        check_params_float(eps: eps)
        check_params_integer(min_samples: min_samples)
        @params = {}
        @params[:eps] = eps
        @params[:min_samples] = min_samples
        @core_sample_ids = nil
        @labels = nil
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> DBSCAN
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [DBSCAN] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        partial_fit(x)
        self
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        partial_fit(x)
        labels
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          core_sample_ids: @core_sample_ids,
          labels: @labels }
      end

      # Load marshal data.
      # @param obj [Hash] The marshal data, with the keys produced by #marshal_dump.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @core_sample_ids = obj[:core_sample_ids]
        @labels = obj[:labels]
        nil
      end

      private

      # Run the cluster analysis on x and store the result in @labels and @core_sample_ids.
      #
      # Label values used while the analysis is running:
      # -2 = not visited yet, -1 = noise, 0 or greater = cluster index.
      def partial_fit(x)
        cluster_id = 0
        n_samples = x.shape[0]
        @core_sample_ids = []
        @labels = Numo::Int32.zeros(n_samples) - 2
        n_samples.times do |q|
          # Skip samples that have already been labeled as noise or as a cluster member.
          next if @labels[q] >= -1
          cluster_id += 1 if expand_cluster(x, q, cluster_id)
        end
        @core_sample_ids = Numo::Int32[*@core_sample_ids]
        nil
      end

      # Try to grow a new cluster starting from the sample query_id.
      #
      # @return [Boolean] true if a cluster was created, false if the sample was labeled as noise.
      def expand_cluster(x, query_id, cluster_id)
        target_ids = region_query(x[query_id, true], x)
        if target_ids.size < @params[:min_samples]
          @labels[query_id] = -1
          return false
        end

        @labels[target_ids] = cluster_id
        # Only the samples that pass the density test are core samples;
        # the neighbors collected in target_ids may be mere border samples.
        @core_sample_ids.push(query_id)
        target_ids.delete(query_id)
        while (m = target_ids.shift)
          neighbor_ids = region_query(x[m, true], x)
          next if neighbor_ids.size < @params[:min_samples]
          @core_sample_ids.push(m)
          neighbor_ids.each do |n|
            # Unvisited samples are queued for expansion; noise samples only join the cluster.
            target_ids.push(n) if @labels[n] < -1
            @labels[n] = cluster_id if @labels[n] <= -1
          end
        end
        true
      end

      # Return the indices of the targets whose Euclidean distance to query is less than eps.
      #
      # @return [Array<Integer>]
      def region_query(query, targets)
        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
        distance_arr.lt(@params[:eps]).where.to_a
      end
    end
  end
end
@@ -0,0 +1,138 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
module Rumale
  # This module consists of classes that implement cluster analysis methods.
  module Clustering
    # KMeans is a class that implements K-Means cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMeans
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, a new seed is generated and stored in the parameters.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
        check_params_float(tol: tol)
        check_params_string(init: init)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
        @params = {}
        @params[:n_clusters] = n_clusters
        @params[:init] = init == 'random' ? 'random' : 'k-means++'
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        # Random.new_seed provides a fresh seed without reseeding the process-wide
        # random generator, which a bare `srand` call does as a side effect.
        @params[:random_seed] = random_seed || Random.new_seed
        @cluster_centers = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMeans
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [KMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        init_cluster_centers(x)
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = cluster_labels.eq(n)
            # A centroid without assigned samples is left where it is.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Terminate when the mean displacement of the centroids falls to the tolerance.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        check_sample_array(x)
        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        fit(x)
        predict(x)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          cluster_centers: @cluster_centers,
          rng: @rng }
      end

      # Load marshal data.
      # @param obj [Hash] The marshal data, with the keys produced by #marshal_dump.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cluster_centers = obj[:cluster_centers]
        @rng = obj[:rng]
        nil
      end

      private

      # Return, for each sample, the index of the nearest centroid.
      def assign_cluster(x)
        distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index yields positions in the flattened distance matrix, so the position
        # of each row start is subtracted to obtain the column (= cluster) index.
        row_offsets = Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
        distance_matrix.min_index(axis: 1) - row_offsets
      end

      # Initialize @cluster_centers from the samples in x.
      #
      # The centroids are first drawn at random among the samples. With the 'k-means++'
      # method, every centroid but the first is then redrawn with a probability
      # proportional to the squared distance to the nearest centroid already chosen.
      def init_cluster_centers(x)
        # random initialize
        n_samples = x.shape[0]
        rand_id = [*0...n_samples].sample(@params[:n_clusters], random: @rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'
        # k-means++ initialize
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          squared_distances = min_distances**2
          probs = squared_distances / squared_distances.sum
          cum_probs = probs.cumsum
          selected_id = cum_probs.gt(@rng.rand).where.to_a.first
          # No sample is selected when every sample coincides with a chosen centroid
          # (the weights are then 0 / 0) or when rounding leaves the cumulative sum
          # below the drawn value. The randomly drawn centroid is kept in that case.
          next if selected_id.nil?
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end
    end
  end
end