rumale 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,14 @@
1
+ #!/usr/bin/env ruby
2
+
3
+ require 'bundler/setup'
4
+ require 'rumale'
5
+
6
+ # You can add fixtures and/or initialization code here to make experimenting
7
+ # with your gem easier. You can also use a different console, if you like.
8
+
9
+ # (If you use this, don't forget to add pry to your Gemfile!)
10
+ # require 'pry'
11
+ # Pry.start
12
+
13
+ require 'irb'
14
+ IRB.start(__FILE__)
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ IFS=$'\n\t'
4
+ set -vx
5
+
6
+ bundle install
7
+
8
+ # Do any other automated setup that you need to do here
@@ -0,0 +1,70 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require 'rumale/version'
6
+ require 'rumale/validation'
7
+ require 'rumale/values'
8
+ require 'rumale/utils'
9
+ require 'rumale/pairwise_metric'
10
+ require 'rumale/dataset'
11
+ require 'rumale/probabilistic_output'
12
+ require 'rumale/base/base_estimator'
13
+ require 'rumale/base/classifier'
14
+ require 'rumale/base/regressor'
15
+ require 'rumale/base/cluster_analyzer'
16
+ require 'rumale/base/transformer'
17
+ require 'rumale/base/splitter'
18
+ require 'rumale/base/evaluator'
19
+ require 'rumale/optimizer/sgd'
20
+ require 'rumale/optimizer/rmsprop'
21
+ require 'rumale/optimizer/nadam'
22
+ require 'rumale/optimizer/yellow_fin'
23
+ require 'rumale/pipeline/pipeline'
24
+ require 'rumale/kernel_approximation/rbf'
25
+ require 'rumale/linear_model/base_linear_model'
26
+ require 'rumale/linear_model/svc'
27
+ require 'rumale/linear_model/svr'
28
+ require 'rumale/linear_model/logistic_regression'
29
+ require 'rumale/linear_model/linear_regression'
30
+ require 'rumale/linear_model/ridge'
31
+ require 'rumale/linear_model/lasso'
32
+ require 'rumale/kernel_machine/kernel_svc'
33
+ require 'rumale/polynomial_model/base_factorization_machine'
34
+ require 'rumale/polynomial_model/factorization_machine_classifier'
35
+ require 'rumale/polynomial_model/factorization_machine_regressor'
36
+ require 'rumale/multiclass/one_vs_rest_classifier'
37
+ require 'rumale/nearest_neighbors/k_neighbors_classifier'
38
+ require 'rumale/nearest_neighbors/k_neighbors_regressor'
39
+ require 'rumale/naive_bayes/naive_bayes'
40
+ require 'rumale/tree/node'
41
+ require 'rumale/tree/base_decision_tree'
42
+ require 'rumale/tree/decision_tree_classifier'
43
+ require 'rumale/tree/decision_tree_regressor'
44
+ require 'rumale/ensemble/ada_boost_classifier'
45
+ require 'rumale/ensemble/ada_boost_regressor'
46
+ require 'rumale/ensemble/random_forest_classifier'
47
+ require 'rumale/ensemble/random_forest_regressor'
48
+ require 'rumale/clustering/k_means'
49
+ require 'rumale/clustering/dbscan'
50
+ require 'rumale/decomposition/pca'
51
+ require 'rumale/decomposition/nmf'
52
+ require 'rumale/preprocessing/l2_normalizer'
53
+ require 'rumale/preprocessing/min_max_scaler'
54
+ require 'rumale/preprocessing/standard_scaler'
55
+ require 'rumale/preprocessing/label_encoder'
56
+ require 'rumale/preprocessing/one_hot_encoder'
57
+ require 'rumale/model_selection/k_fold'
58
+ require 'rumale/model_selection/stratified_k_fold'
59
+ require 'rumale/model_selection/cross_validation'
60
+ require 'rumale/model_selection/grid_search_cv'
61
+ require 'rumale/evaluation_measure/accuracy'
62
+ require 'rumale/evaluation_measure/precision'
63
+ require 'rumale/evaluation_measure/recall'
64
+ require 'rumale/evaluation_measure/f_score'
65
+ require 'rumale/evaluation_measure/log_loss'
66
+ require 'rumale/evaluation_measure/r2_score'
67
+ require 'rumale/evaluation_measure/mean_squared_error'
68
+ require 'rumale/evaluation_measure/mean_absolute_error'
69
+ require 'rumale/evaluation_measure/purity'
70
+ require 'rumale/evaluation_measure/normalized_mutual_information'
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Rumale
4
+ # This module consists of basic mix-in modules shared by the estimators.
5
+ module Base
6
+ # Base module for all estimators in Rumale; it only exposes the reader for @params.
7
+ module BaseEstimator
8
+ # Return the parameters of the estimator (the value of the @params instance variable).
9
+ # @return [Hash]
10
+ attr_reader :params
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/accuracy'
5
+
6
module Rumale
  module Base
    # Module for all classifiers in Rumale.
    #
    # Including classes are expected to override {#fit} and {#predict};
    # {#score} is implemented on top of {#predict}.
    module Classifier
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored so that a call written the way a
      # concrete classifier is called (e.g. +fit(x, y)+) reports the missing
      # override instead of an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def fit(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting labels.
      #
      # Any arguments are accepted and ignored; {#score} calls +predict(x)+, so the
      # stub must tolerate that call to raise the informative error below.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def predict(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the mean accuracy of the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Mean accuracy
      def score(x, y)
        # The check_* helpers are not defined in this module; presumably they come from Validation.
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Accuracy.new
        evaluator.score(y, predict(x))
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/purity'
5
+
6
module Rumale
  module Base
    # Module for all clustering algorithms in Rumale.
    #
    # Including classes are expected to override {#fit_predict};
    # {#score} is implemented on top of it.
    module ClusterAnalyzer
      include Validation

      # An abstract method for analyzing clusters and predicting cluster indices.
      #
      # Any arguments are accepted and ignored; {#score} calls +fit_predict(x)+, so
      # the stub must tolerate that call to raise the informative error below
      # rather than an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def fit_predict(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate purity of clustering result.
      #
      # Note that this runs {#fit_predict} on the given data, i.e. the clusters are
      # analyzed again with x.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
      # @return [Float] Purity
      def score(x, y)
        # The check_* helpers are not defined in this module; presumably they come from Validation.
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        evaluator = Rumale::EvaluationMeasure::Purity.new
        evaluator.score(y, fit_predict(x))
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all evaluation measures in Rumale.
    module Evaluator
      include Validation

      # An abstract method for evaluation of model.
      #
      # Any arguments are accepted and ignored so that a call written the way a
      # concrete evaluator is called (e.g. +score(y_true, y_pred)+) reports the
      # missing override instead of an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def score(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/evaluation_measure/r2_score'
5
+
6
module Rumale
  module Base
    # Module for all regressors in Rumale.
    #
    # Including classes are expected to override {#fit} and {#predict};
    # {#score} is implemented on top of {#predict}.
    module Regressor
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored so that a call written the way a
      # concrete regressor is called (e.g. +fit(x, y)+) reports the missing
      # override instead of an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def fit(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for predicting target values.
      #
      # Any arguments are accepted and ignored; {#score} calls +predict(x)+, so the
      # stub must tolerate that call to raise the informative error below.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def predict(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Calculate the coefficient of determination for the given testing data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
      # @return [Float] Coefficient of determination
      def score(x, y)
        # The check_* helpers are not defined in this module; presumably they come from Validation.
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        evaluator = Rumale::EvaluationMeasure::R2Score.new
        evaluator.score(y, predict(x))
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all validation methods in Rumale.
    module Splitter
      include Validation

      # Return the number of splits.
      # @return [Integer]
      attr_reader :n_splits

      # An abstract method for splitting dataset.
      #
      # Any arguments are accepted and ignored so that a call written the way a
      # concrete splitter is called (e.g. +split(x, y)+) reports the missing
      # override instead of an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def split(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+
5
module Rumale
  module Base
    # Module for all transformers in Rumale.
    module Transformer
      include Validation

      # An abstract method for fitting a model.
      #
      # Any arguments are accepted and ignored so that a call written the way a
      # concrete transformer is called (e.g. +fit(x)+) reports the missing
      # override instead of an arity mismatch.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def fit(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # An abstract method for fitting a model and transforming given data.
      #
      # Any arguments are accepted and ignored, for the same reason as {#fit}.
      #
      # @raise [NotImplementedError] Always, unless the including class overrides it.
      def fit_transform(*_args)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end
    end
  end
end
@@ -0,0 +1,125 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
module Rumale
  module Clustering
    # DBSCAN is a class that implements DBSCAN cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
    class DBSCAN
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the core sample indices, i.e. the indices of the samples whose
      # eps-neighborhood contains at least min_samples samples.
      # @return [Numo::Int32] (shape: [n_core_samples])
      attr_reader :core_sample_ids

      # Return the cluster labels. The negative cluster label indicates that the point is noise.
      # @return [Numo::Int32] (shape: [n_samples])
      attr_reader :labels

      # Create a new cluster analyzer with DBSCAN method.
      #
      # @param eps [Float] The radius of neighborhood.
      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
      def initialize(eps: 0.5, min_samples: 5)
        check_params_float(eps: eps)
        check_params_integer(min_samples: min_samples)
        @params = {}
        @params[:eps] = eps
        @params[:min_samples] = min_samples
        @core_sample_ids = nil
        @labels = nil
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> DBSCAN
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [DBSCAN] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        partial_fit(x)
        self
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        partial_fit(x)
        labels
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          core_sample_ids: @core_sample_ids,
          labels: @labels }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @core_sample_ids = obj[:core_sample_ids]
        @labels = obj[:labels]
        nil
      end

      private

      # Run the clustering over all samples and fill @labels and @core_sample_ids.
      #
      # Label convention used while running: -2 means "not visited yet",
      # -1 means noise, and values >= 0 are cluster indices.
      def partial_fit(x)
        cluster_id = 0
        n_samples = x.shape[0]
        @core_sample_ids = []
        @labels = Numo::Int32.zeros(n_samples) - 2
        n_samples.times do |q|
          # Skip samples that were already classified (as noise or as a cluster member).
          next if @labels[q] >= -1
          cluster_id += 1 if expand_cluster(x, q, cluster_id)
        end
        @core_sample_ids = Numo::Int32[*@core_sample_ids]
        nil
      end

      # Try to grow a new cluster starting from the sample at query_id.
      #
      # @return [Boolean] true when a cluster was created, false when the query
      #   sample does not have enough neighbors and was marked as noise.
      def expand_cluster(x, query_id, cluster_id)
        target_ids = region_query(x[query_id, true], x)
        if target_ids.size < @params[:min_samples]
          @labels[query_id] = -1
          return false
        end

        @labels[target_ids] = cluster_id
        # Only the query sample is known to be a core point here; its neighbors
        # may be mere border points, so they are recorded later, if and when
        # their own neighborhood turns out to be dense enough.
        @core_sample_ids.push(query_id)
        target_ids.delete(query_id)
        while (m = target_ids.shift)
          neighbor_ids = region_query(x[m, true], x)
          next if neighbor_ids.size < @params[:min_samples]
          @core_sample_ids.push(m)
          neighbor_ids.each do |n|
            # Only unvisited samples are queued for expansion; samples previously
            # marked as noise join the cluster without being expanded.
            target_ids.push(n) if @labels[n] < -1
            @labels[n] = cluster_id if @labels[n] <= -1
          end
        end
        true
      end

      # Return the indices (as a Ruby Array) of the targets whose Euclidean
      # distance to the query is strictly less than eps.
      def region_query(query, targets)
        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
        distance_arr.lt(@params[:eps]).where.to_a
      end
    end
  end
end
@@ -0,0 +1,138 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/cluster_analyzer'
5
+ require 'rumale/pairwise_metric'
6
+
7
module Rumale
  # This module consists of classes that implement cluster analysis methods.
  module Clustering
    # KMeans is a class that implements K-Means cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = Rumale::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMeans
      include Base::BaseEstimator
      include Base::ClusterAnalyzer

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' is treated as 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
        check_params_float(tol: tol)
        check_params_string(init: init)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
        @params = {}
        @params[:n_clusters] = n_clusters
        @params[:init] = init == 'random' ? 'random' : 'k-means++'
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        @params[:random_seed] = random_seed
        # Kernel#srand without arguments reseeds the global generator and returns
        # the previous global seed, which is what gets stored here.
        @params[:random_seed] ||= srand
        @cluster_centers = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMeans
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [KMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        init_cluster_centers(x)
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = cluster_labels.eq(n)
            # A centroid without assigned samples keeps its previous position.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Stop once the mean Euclidean shift of the centroids is within the tolerance.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        check_sample_array(x)
        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        fit(x)
        predict(x)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          cluster_centers: @cluster_centers,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cluster_centers = obj[:cluster_centers]
        @rng = obj[:rng]
        nil
      end

      private

      # Return, for each sample, the index of the nearest centroid.
      def assign_cluster(x)
        distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index yields positions in the flattened matrix; subtracting each row's
        # starting offset turns them into column (= centroid) indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Pick the initial centroids, randomly or with the k-means++ seeding.
      def init_cluster_centers(x)
        # random initialize
        n_samples = x.shape[0]
        rand_id = [*0...n_samples].sample(@params[:n_clusters], random: @rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'
        # k-means++ initialize: keep the first random centroid and re-draw the others
        # with probability proportional to the squared distance to the nearest
        # centroid chosen so far.
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          probs = min_distances**2 / (min_distances**2).sum
          cum_probs = probs.cumsum
          # Rounding in the cumulative sum (last value slightly below 1.0) or
          # all-zero distances (NaN probabilities) can leave no entry above the
          # draw; fall back to the last sample instead of indexing with nil.
          selected_id = cum_probs.gt(@rng.rand).where.to_a.first || (n_samples - 1)
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end
    end
  end
end
+ end