svmkit 0.7.3 → 0.8.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
@@ -1,140 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/base_estimator'
5
- require 'svmkit/base/cluster_analyzer'
6
- require 'svmkit/pairwise_metric'
7
-
module SVMKit
  # This module consists of classes that implement cluster analysis methods.
  module Clustering
    # KMeans is a class that implements K-Means cluster analysis.
    # The current implementation uses the Euclidean distance for analyzing the clusters.
    #
    # @example
    #   analyzer = SVMKit::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
    #   cluster_labels = analyzer.fit_predict(samples)
    #
    # *Reference*
    # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
    class KMeans
      include Base::BaseEstimator
      include Base::ClusterAnalyzer
      include Validation

      # Return the centroids.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_reader :cluster_centers

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new cluster analyzer with K-Means method.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param init [String] The initialization method for centroids ('random' or 'k-means++').
      #   Any value other than 'random' selects 'k-means++'.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, the value returned by Kernel#srand is used instead.
      def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
        check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
        check_params_float(tol: tol)
        check_params_string(init: init)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
        @params = {}
        @params[:n_clusters] = n_clusters
        @params[:init] = init == 'random' ? 'random' : 'k-means++'
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @cluster_centers = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Analyze clusters with given training data.
      #
      # @overload fit(x) -> KMeans
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [KMeans] The learned cluster analyzer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        init_cluster_centers(x)
        @params[:max_iter].times do |_t|
          cluster_labels = assign_cluster(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = cluster_labels.eq(n)
            # A cluster without any assigned sample keeps its previous centroid.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count > 0
          end
          # Stop once the mean Euclidean displacement of the centroids is within the tolerance.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
        self
      end

      # Predict cluster labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def predict(x)
        check_sample_array(x)
        assign_cluster(x)
      end

      # Analyze clusters and assign samples to clusters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
      def fit_predict(x)
        check_sample_array(x)
        fit(x)
        predict(x)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          cluster_centers: @cluster_centers,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cluster_centers = obj[:cluster_centers]
        @rng = obj[:rng]
        nil
      end

      private

      # Return, for every sample, the index of the centroid with the smallest distance.
      def assign_cluster(x)
        distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # The per-row minimum positions are turned into column (cluster) indices by
        # subtracting 0, n_clusters, 2 * n_clusters, ... from consecutive rows.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Choose the initial centroids, either uniformly at random or with k-means++ seeding.
      def init_cluster_centers(x)
        # random initialize
        n_samples = x.shape[0]
        rand_id = [*0...n_samples].sample(@params[:n_clusters], random: @rng)
        @cluster_centers = x[rand_id, true].dup
        return unless @params[:init] == 'k-means++'
        # k-means++ initialize
        (1...@params[:n_clusters]).each do |n|
          distance_matrix = PairwiseMetric.euclidean_distance(x, @cluster_centers[0...n, true])
          min_distances = distance_matrix.flatten[distance_matrix.min_index(axis: 1)]
          probs = min_distances**2 / (min_distances**2).sum
          cum_probs = probs.cumsum
          selected_id = cum_probs.gt(@rng.rand).where.to_a.first
          # No candidate exists when every sample coincides with an already chosen
          # centroid (0 / 0 makes the probabilities NaN) or when rounding leaves the
          # last cumulative value below the drawn number. The randomly picked
          # centroid assigned above is kept in that case instead of failing.
          next if selected_id.nil?
          @cluster_centers[n, true] = x[selected_id, true].dup
        end
      end
    end
  end
end
@@ -1,109 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'csv'
4
-
module SVMKit
  # Module for loading and saving a dataset file.
  module Dataset
    class << self
      # Load a dataset with the libsvm file format into Numo::NArray.
      #
      # Blank lines and empty tokens caused by repeated, leading, or trailing
      # spaces are ignored.
      #
      # @param filename [String] A path to a dataset file.
      # @param zero_based [Boolean] Whether the column index starts from 0 (true) or 1 (false).
      #
      # @return [Array<Numo::NArray>]
      #   Returns array containing the (n_samples x n_features) matrix for feature vectors
      #   and (n_samples) vector for labels or target values.
      def load_libsvm_file(filename, zero_based: false)
        ftvecs = []
        labels = []
        max_feature_idx = -1
        CSV.foreach(filename, col_sep: "\s", headers: false) do |line|
          # A blank line yields no usable token, so there is nothing to parse.
          next if line.compact.reject(&:empty?).empty?
          label, ftvec, max_idx = parse_libsvm_line(line, zero_based)
          labels.push(label)
          ftvecs.push(ftvec)
          max_feature_idx = max_idx if max_feature_idx < max_idx
        end
        # The number of columns is the largest zero-based index plus one.
        [convert_to_matrix(ftvecs, max_feature_idx + 1), Numo::NArray.asarray(labels)]
      end

      # Dump the dataset with the libsvm file format.
      #
      # @param data [Numo::NArray] (shape: [n_samples, n_features]) matrix consisting of feature vectors.
      # @param labels [Numo::NArray] (shape: [n_samples]) matrix consisting of labels or target values.
      # @param filename [String] A path to the output libsvm file.
      # @param zero_based [Boolean] Whether the column index starts from 0 (true) or 1 (false).
      def dump_libsvm_file(data, labels, filename, zero_based: false)
        n_samples = [data.shape[0], labels.shape[0]].min
        single_label = labels.shape[1].nil?
        label_type = detect_dtype(labels)
        value_type = detect_dtype(data)
        File.open(filename, 'w') do |file|
          n_samples.times do |n|
            label = single_label ? labels[n] : labels[n, true].to_a
            file.puts(dump_libsvm_line(label, data[n, true],
                                       label_type, value_type, zero_based))
          end
        end
      end

      private

      # Split the tokens of one line into the label, the list of
      # [zero-based index, value] pairs, and the largest index seen (-1 if none).
      def parse_libsvm_line(line, zero_based)
        # Repeated or leading spaces show up as nil/empty tokens; drop them so that
        # they neither end the scan early nor get mistaken for the label.
        tokens = line.compact.reject(&:empty?)
        label = parse_label(tokens.shift)
        adj_idx = zero_based == false ? 1 : 0
        ftvec = tokens.map do |el|
          idx, val = el.split(':')
          [idx.to_i - adj_idx, parse_number(val)]
        end
        max_idx = ftvec.map(&:first).max || -1
        [label, ftvec, max_idx]
      end

      # Parse a label token; comma separated values become an array of labels.
      def parse_label(label)
        lbl_arr = label.split(',').map { |lbl| parse_number(lbl) }
        lbl_arr.size > 1 ? lbl_arr : lbl_arr[0]
      end

      # Convert a token into an Integer when it is written exactly as one, otherwise into a Float.
      def parse_number(str)
        str.to_i.to_s == str ? str.to_i : str.to_f
      end

      # Expand the sparse [index, value] lists into a dense matrix with n_features columns.
      def convert_to_matrix(data, n_features)
        mat = data.map do |ft|
          vec = Array.new(n_features) { 0 }
          ft.each { |idx, val| vec[idx] = val }
          vec
        end
        Numo::NArray.asarray(mat)
      end

      # Choose the format directive used to print the elements of the given array.
      def detect_dtype(data)
        case Numo::NArray.array_type(data).to_s
        when 'Numo::Int8', 'Numo::Int16', 'Numo::Int32', 'Numo::Int64',
             'Numo::UInt8', 'Numo::UInt16', 'Numo::UInt32', 'Numo::UInt64'
          '%d'
        when 'Numo::SFloat', 'Numo::DFloat'
          '%.10g'
        else
          '%s'
        end
      end

      # Build one output line: the label followed by "index:value" for every non-zero value.
      def dump_libsvm_line(label, ftvec, label_type, value_type, zero_based)
        offset = zero_based == false ? 1 : 0
        fields = [dump_label(label, label_type.to_s)]
        ftvec.to_a.each_with_index do |val, n|
          fields << format("%d:#{value_type}", n + offset, val) if val != 0.0
        end
        fields.join(' ')
      end

      # Format a single label, or a comma separated list for an array of labels.
      def dump_label(label, label_type_str)
        if label.is_a?(Array)
          label.map { |lbl| format(label_type_str, lbl) }.join(',')
        else
          format(label_type_str, label)
        end
      end
    end
  end
end
@@ -1,147 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/base_estimator'
5
- require 'svmkit/base/transformer'
6
-
module SVMKit
  module Decomposition
    # NMF is a class that implements Non-negative Matrix Factorization.
    #
    # @example
    #   decomposer = SVMKit::Decomposition::NMF.new(n_components: 2)
    #   representation = decomposer.fit_transform(samples)
    #
    # *Reference*
    # - W. Xu, X. Liu, and Y.Gong, "Document Clustering Based On Non-negative Matrix Factorization," Proc. SIGIR' 03 , pp. 267--273, 2003.
    class NMF
      include Base::BaseEstimator
      include Base::Transformer
      include Validation

      # Returns the factorization matrix.
      # @return [Numo::DFloat] (shape: [n_components, n_features])
      attr_reader :components

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new transformer with NMF.
      #
      # @param n_components [Integer] The number of components.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param eps [Float] A small value close to zero to avoid zero division error.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, the value returned by Kernel#srand is used instead.
      def initialize(n_components: 2, max_iter: 500, tol: 1.0e-4, eps: 1.0e-16, random_seed: nil)
        check_params_integer(n_components: n_components, max_iter: max_iter)
        check_params_float(tol: tol, eps: eps)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_components: n_components, max_iter: max_iter, tol: tol, eps: eps)
        @params = {
          n_components: n_components,
          max_iter: max_iter,
          tol: tol,
          eps: eps,
          random_seed: random_seed || srand
        }
        @components = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @overload fit(x) -> NMF
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @return [NMF] The learned transformer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        partial_fit(x)
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, _y = nil)
        check_sample_array(x)
        partial_fit(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      def transform(x)
        check_sample_array(x)
        partial_fit(x, false)
      end

      # Inverse transform the given transformed data with the learned model.
      #
      # @param z [Numo::DFloat] (shape: [n_samples, n_components]) The data to be restored into original space with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_features]) The restored data.
      def inverse_transform(z)
        check_sample_array(z)
        z.dot(@components)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params, components: @components, rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params, @components, @rng = obj.values_at(:params, :components, :rng)
        nil
      end

      private

      # Run the multiplicative update iterations and return the coefficient matrix.
      # The components are re-initialized and updated only when update_comps is true;
      # otherwise the previously learned components are used as they are.
      def partial_fit(x, update_comps = true)
        n_samples, n_features = x.shape
        n_components = @params[:n_components]
        eps = @params[:eps]
        scale = Math.sqrt(x.mean / n_components)
        # The components are drawn before the coefficients, which fixes the order
        # in which the random generator is consumed.
        @components = rand_uniform([n_components, n_features]) * scale if update_comps
        weights = rand_uniform([n_samples, n_components]) * scale
        @params[:max_iter].times do
          if update_comps
            weights_t = weights.transpose
            @components *= weights_t.dot(x) / (weights_t.dot(weights).dot(@components) + eps)
          end
          components_t = @components.transpose
          weights *= x.dot(components_t) / (weights.dot(@components).dot(components_t) + eps)
          # Rescale by the row norms of the components.
          row_norms = Numo::NMath.sqrt((@components**2).sum(1)) + eps
          @components /= row_norms.expand_dims(1) if update_comps
          weights *= row_norms
          # Stop when the mean squared reconstruction error per sample is small enough.
          residual = ((x - weights.dot(@components))**2).sum(1).mean
          break if residual < @params[:tol]
        end
        weights
      end

      # Create a matrix of the given two-dimensional shape filled with values drawn from the random generator.
      def rand_uniform(shape)
        n_rows, n_cols = shape
        samples = Array.new(n_rows * n_cols) { @rng.rand }
        Numo::DFloat.asarray(samples).reshape(n_rows, n_cols)
      end
    end
  end
end
@@ -1,150 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'svmkit/validation'
4
- require 'svmkit/base/base_estimator'
5
- require 'svmkit/base/transformer'
6
-
module SVMKit
  # Module for matrix decomposition algorithms.
  module Decomposition
    # PCA is a class that implements Principal Component Analysis.
    #
    # @example
    #   decomposer = SVMKit::Decomposition::PCA.new(n_components: 2)
    #   representation = decomposer.fit_transform(samples)
    #
    # *Reference*
    # - A. Sharma and K K. Paliwal, "Fast principal component analysis using fixed-point algorithm," Pattern Recognition Letters, 28, pp. 1151--1155, 2007.
    class PCA
      include Base::BaseEstimator
      include Base::Transformer
      include Validation

      # Returns the principal components.
      # @return [Numo::DFloat] (shape: [n_components, n_features])
      attr_reader :components

      # Returns the mean vector.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :mean

      # Return the random generator.
      # @return [Random]
      attr_reader :rng

      # Create a new transformer with PCA.
      #
      # @param n_components [Integer] The number of principal components.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param tol [Float] The tolerance of termination criterion.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   When nil, the value returned by Kernel#srand is used instead.
      def initialize(n_components: 2, max_iter: 100, tol: 1.0e-4, random_seed: nil)
        check_params_integer(n_components: n_components, max_iter: max_iter)
        check_params_float(tol: tol)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        check_params_positive(n_components: n_components, max_iter: max_iter, tol: tol)
        @params = {}
        @params[:n_components] = n_components
        @params[:max_iter] = max_iter
        @params[:tol] = tol
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @components = nil
        @mean = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @overload fit(x) -> PCA
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @return [PCA] The learned transformer itself.
      def fit(x, _y = nil)
        check_sample_array(x)
        # initialize some variables.
        @components = nil
        n_samples, n_features = x.shape
        # centering.
        @mean = x.mean(0)
        centered_x = x - @mean
        # optimization.
        covariance_mat = centered_x.transpose.dot(centered_x) / (n_samples - 1)
        @params[:n_components].times do
          comp_vec = random_vec(n_features)
          @params[:max_iter].times do
            updated = orthogonalize(covariance_mat.dot(comp_vec))
            # Converged once the updated vector is (nearly) parallel to the previous one.
            break if (updated.dot(comp_vec) - 1).abs < @params[:tol]
            comp_vec = updated
          end
          # The first component is stored as a vector; later ones are stacked below it.
          @components = @components.nil? ? comp_vec : Numo::NArray.vstack([@components, comp_vec])
        end
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, _y = nil)
        check_sample_array(x)
        fit(x).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      def transform(x)
        check_sample_array(x)
        (x - @mean).dot(@components.transpose)
      end

      # Inverse transform the given transformed data with the learned model.
      #
      # @param z [Numo::DFloat] (shape: [n_samples, n_components]) The data to be restored into original space with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_features]) The restored data.
      def inverse_transform(z)
        check_sample_array(z)
        # A single learned component is kept as a vector; treat it as a one-row matrix.
        c = @components.shape[1].nil? ? @components.expand_dims(0) : @components
        z.dot(c) + @mean
      end

      # Dump marshal data.
      # @return [Hash] The marshal data.
      def marshal_dump
        { params: @params,
          components: @components,
          mean: @mean,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @components = obj[:components]
        @mean = obj[:mean]
        @rng = obj[:rng]
        nil
      end

      private

      # Remove from the given vector its projections onto the components found so far,
      # then rescale the result to unit length.
      def orthogonalize(pcvec)
        unless @components.nil?
          delta = @components.dot(pcvec) * @components.transpose
          delta = delta.sum(1) unless delta.shape[1].nil?
          pcvec -= delta
        end
        # The small constant guards the division against a zero-length vector, so it
        # has to be part of the denominator rather than added to the quotient.
        pcvec / (Math.sqrt((pcvec**2).sum.abs) + 1.0e-12)
      end

      # Create a vector of the given length filled with values drawn from the random generator.
      def random_vec(n_features)
        Numo::DFloat[*(Array.new(n_features) { @rng.rand })]
      end
    end
  end
end