rumale 0.18.3 → 0.18.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a3def8720c4043695c73b01ba991b8d394e8c684bf0fc74458c544d1d17d1512
4
- data.tar.gz: c758d68350a6f97706aab6ecb1934af21aac0cbd9d4ca03cd8dc940c982411b3
3
+ metadata.gz: a4ec0f04029a88ea7950bb74bf9e0bd970cf8b546253da4ef4545244ad0adf68
4
+ data.tar.gz: ed927fb58359fe49d4834593ec2034535f5ac8d37c773d6cb5a53227c4777f61
5
5
  SHA512:
6
- metadata.gz: 23202fa7a3a0f61d137575d5742a58e04ba4f84a912e287013465b5620db1b2b88f96fde72ea4a138c4f4386160c7005ef94c8288b8548576c588aa705febd61
7
- data.tar.gz: 346cf325f959a52ad3b530f4d3a7488b1f313501ad0b244d9a43357e1cc636ef6a3f00781a6a7cdc763bc1e6df291ded3967fea61cedc4ce58a1ca368d76f9e5
6
+ metadata.gz: 58727ddc9c5c6f9c12ac8231f57295988795bf07ec6f748543fff924a4d8085472a210789992ba2a3cc4827300823d141313da1fd20a10be3b5970b40871e4e3
7
+ data.tar.gz: cc1af7ff552a9fe516e0588081ad6645cb03d4c8dd774492ea35c3e1a28eece7383dc88604c41e1cc1417f5d05af1324a7b97c121cb7ee76bcc697e4173195be
@@ -1,3 +1,8 @@
1
+ # 0.18.4
2
+ - Add transformer class for [KernelFDA](https://yoshoku.github.io/rumale/doc/Rumale/KernelMachine/KernelFDA.html).
3
+ - Refactor [KernelPCA](https://yoshoku.github.io/rumale/doc/Rumale/KernelMachine/KernelPCA.html).
4
+ - Fix API documentation.
5
+
1
6
  # 0.18.3
2
7
  - Fix API documentation on [KNeighborsRegressor](https://yoshoku.github.io/rumale/doc/Rumale/NearestNeighbors/KNeighborsRegressor.html).
3
8
  - Refactor [rbf_kernel](https://yoshoku.github.io/rumale/doc/Rumale/PairwiseMetric.html#rbf_kernel-class_method) method.
@@ -39,6 +39,7 @@ require 'rumale/linear_model/lasso'
39
39
  require 'rumale/linear_model/elastic_net'
40
40
  require 'rumale/kernel_machine/kernel_svc'
41
41
  require 'rumale/kernel_machine/kernel_pca'
42
+ require 'rumale/kernel_machine/kernel_fda'
42
43
  require 'rumale/kernel_machine/kernel_ridge'
43
44
  require 'rumale/polynomial_model/base_factorization_machine'
44
45
  require 'rumale/polynomial_model/factorization_machine_classifier'
@@ -13,7 +13,7 @@ module Rumale
13
13
  # @example
14
14
  # y_true = Numo::Int32[2, 0, 2, 2, 0, 1]
15
15
  # y_pred = Numo::Int32[0, 0, 2, 2, 0, 2]
16
- # p confusion_matrix(y_true, y_pred)
16
+ # p Rumale::EvaluationMeasure.confusion_matrix(y_true, y_pred)
17
17
  #
18
18
  # # Numo::Int32#shape=[3,3]
19
19
  # # [[2, 0, 0],
@@ -0,0 +1,120 @@
1
# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'

module Rumale
  module KernelMachine
    # KernelFDA is a class that implements Kernel Fisher Discriminant Analysis.
    #
    # @example
    #   require 'numo/linalg/autoloader'
    #
    #   kernel_mat_train = Rumale::PairwiseMetric.rbf_kernel(x_train)
    #   kfda = Rumale::KernelMachine::KernelFDA.new
    #   mapped_training_samples = kfda.fit_transform(kernel_mat_train, y_train)
    #
    #   kernel_mat_test = Rumale::PairwiseMetric.rbf_kernel(x_test, x_train)
    #   mapped_test_samples = kfda.transform(kernel_mat_test)
    #
    # *Reference*
    # - Baudat, G. and Anouar, F., "Generalized Discriminant Analysis using a Kernel Approach," Neural Computation, vol. 12, pp. 2385--2404, 2000.
    class KernelFDA
      include Base::BaseEstimator
      include Base::Transformer

      # Returns the eigenvectors for embedding.
      # @return [Numo::DFloat] (shape: [n_training_samples, n_components])
      attr_reader :alphas

      # Create a new transformer with Kernel FDA.
      #
      # @param n_components [Integer] The number of components.
      #   If nil is given, min(n_training_samples, n_classes - 1) components are used.
      #   Note that #transform returns a 1-D array only when 1 is given explicitly here.
      # @param reg_param [Float] The regularization parameter added to the diagonal of the
      #   denominator matrix of the generalized eigenvalue problem.
      def initialize(n_components: nil, reg_param: 1e-8)
        check_params_numeric_or_nil(n_components: n_components)
        check_params_numeric(reg_param: reg_param)
        @params = {}
        @params[:n_components] = n_components
        @params[:reg_param] = reg_param
        @alphas = nil
        @row_mean = nil
        @all_mean = nil
      end

      # Fit the model with given training data.
      # To execute this method, Numo::Linalg must be loaded.
      #
      # @param x [Numo::DFloat] (shape: [n_training_samples, n_training_samples])
      #   The kernel matrix of the training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_training_samples]) The labels to be used for fitting the model.
      # @raise [ArgumentError] If the kernel matrix is not square.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @return [KernelFDA] The learned transformer itself.
      def fit(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        raise ArgumentError, 'Expect the kernel matrix of training data to be square.' unless x.shape[0] == x.shape[1]
        raise 'KernelFDA#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?

        # initialize some variables.
        n_samples = x.shape[0]
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size
        n_components = if @params[:n_components].nil?
                         [n_samples, n_classes - 1].min
                       else
                         [n_samples, @params[:n_components]].min
                       end

        # Center the kernel matrix in feature space. The row mean and overall mean are
        # kept so that #transform can center test kernel matrices the same way.
        @row_mean = x.mean(0)
        @all_mean = @row_mean.sum.fdiv(n_samples)
        centered_kernel_mat = x - x.mean(1).expand_dims(1) - @row_mean + @all_mean

        # class_mat[i, j] is 1 / n_c when samples i and j both belong to class c, and 0 otherwise.
        class_mat = Numo::DFloat.zeros(n_samples, n_samples)
        @classes.each do |label|
          idx_vec = y.eq(label)
          class_mat += Numo::DFloat.cast(idx_vec).outer(idx_vec) / idx_vec.count
        end

        # Numerator: between-class scatter K_c W K_c^T.
        between_mat = centered_kernel_mat.dot(class_mat).dot(centered_kernel_mat.transpose)
        # Denominator: total scatter K_c K_c^T (as in Baudat & Anouar's GDA criterion).
        # A centered kernel matrix is singular, so reg_param * I is added to keep it positive definite.
        total_mat = centered_kernel_mat.dot(centered_kernel_mat.transpose) + @params[:reg_param] * Numo::DFloat.eye(n_samples)

        # Solve the generalized eigenproblem for the n_components largest eigenvalues.
        # Eigenvalues come back in ascending order, so the columns are reversed to put the largest first.
        _eig_vals, eig_vecs = Numo::Linalg.eigh(
          between_mat, total_mat,
          vals_range: (n_samples - n_components)...n_samples
        )
        @alphas = eig_vecs.reverse(1).dup
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      # To execute this method, Numo::Linalg must be loaded.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_samples])
      #   The kernel matrix of the training data to be used for fitting the model and transformed.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
      #   The kernel matrix between testing samples and training samples to be transformed.
      # @return [Numo::DFloat] (shape: [n_testing_samples, n_components]) The transformed data.
      #   When n_components: 1 was given explicitly, a 1-D array (shape: [n_testing_samples]) is returned.
      def transform(x)
        x = check_convert_sample_array(x)
        # Center the test kernel matrix with the training statistics saved by #fit.
        col_mean = x.sum(1) / @row_mean.shape[0]
        centered_kernel_mat = x - col_mean.expand_dims(1) - @row_mean + @all_mean
        transformed = centered_kernel_mat.dot(@alphas)
        # This checks the user-supplied parameter, not the derived component count, so a default
        # (nil) setting on a two-class problem still yields a 2-D array with one column.
        @params[:n_components] == 1 ? transformed[true, 0].dup : transformed
      end
    end
  end
end
@@ -11,7 +11,7 @@ module Rumale
11
11
  # require 'numo/linalg/autoloader'
12
12
  #
13
13
  # kernel_mat_train = Rumale::PairwiseMetric::rbf_kernel(training_samples)
14
- # kpca = Rumale::KernelMachine::KernelPCA(n_components: 2)
14
+ # kpca = Rumale::KernelMachine::KernelPCA.new(n_components: 2)
15
15
  # mapped_traininig_samples = kpca.fit_transform(kernel_mat_train)
16
16
  #
17
17
  # kernel_mat_test = Rumale::PairwiseMetric::rbf_kernel(test_samples, training_samples)
@@ -27,7 +27,7 @@ module Rumale
27
27
  # @return [Numo::DFloat] (shape: [n_components])
28
28
  attr_reader :lambdas
29
29
 
30
- # Returns the eigenvectros of the centered kernel matrix.
30
+ # Returns the eigenvectors of the centered kernel matrix.
31
31
  # @return [Numo::DFloat] (shape: [n_training_samples, n_components])
32
32
  attr_reader :alphas
33
33
 
@@ -40,6 +40,7 @@ module Rumale
40
40
  @params[:n_components] = n_components
41
41
  @alphas = nil
42
42
  @lambdas = nil
43
+ @transform_mat = nil
43
44
  @row_mean = nil
44
45
  @all_mean = nil
45
46
  end
@@ -63,6 +64,7 @@ module Rumale
63
64
  eig_vals, eig_vecs = Numo::Linalg.eigh(centered_kernel_mat, vals_range: (n_samples - @params[:n_components])...n_samples)
64
65
  @alphas = eig_vecs.reverse(1).dup
65
66
  @lambdas = eig_vals.reverse.dup
67
+ @transform_mat = @alphas.dot((1.0 / Numo::NMath.sqrt(@lambdas)).diag)
66
68
  self
67
69
  end
68
70
 
@@ -87,8 +89,7 @@ module Rumale
87
89
  x = check_convert_sample_array(x)
88
90
  col_mean = x.sum(1) / @row_mean.shape[0]
89
91
  centered_kernel_mat = x - col_mean.expand_dims(1) - @row_mean + @all_mean
90
- transform_mat = @alphas.dot((1.0 / Numo::NMath.sqrt(@lambdas)).diag)
91
- transformed = centered_kernel_mat.dot(transform_mat)
92
+ transformed = centered_kernel_mat.dot(@transform_mat)
92
93
  @params[:n_components] == 1 ? transformed[true, 0].dup : transformed
93
94
  end
94
95
  end
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.18.3'
6
+ VERSION = '0.18.4'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.18.3
4
+ version: 0.18.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-04-04 00:00:00.000000000 Z
11
+ date: 2020-04-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -242,6 +242,7 @@ files:
242
242
  - lib/rumale/feature_extraction/hash_vectorizer.rb
243
243
  - lib/rumale/kernel_approximation/nystroem.rb
244
244
  - lib/rumale/kernel_approximation/rbf.rb
245
+ - lib/rumale/kernel_machine/kernel_fda.rb
245
246
  - lib/rumale/kernel_machine/kernel_pca.rb
246
247
  - lib/rumale/kernel_machine/kernel_ridge.rb
247
248
  - lib/rumale/kernel_machine/kernel_svc.rb