rumale-metric_learning 0.25.0 → 0.26.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8fbc84de9f21a9fda63e31898cf1fa28b2a29dad7f41821f9cf68d64a881906d
4
- data.tar.gz: bf66f89e3d23653947e8bcdbd3c5054c88ecaca13b51d26c89cf3764f2fa6179
3
+ metadata.gz: 7f20a7e0bf72c2a9869a837c493121ce2ec70132d9818f5bfca3dd2b8dd7db94
4
+ data.tar.gz: af5c7339a4ef9164223c18b6e18c9453b081b90be787cc9a23ccb4fc9493566f
5
5
  SHA512:
6
- metadata.gz: 4d63cd938c957f3963d4497b7067fc77171f977a079d4ff94a2d82acf0c4a1123d8ac50833437c2c1bfd66d3decaeabaa772e2bf8630dd7ed7443788249d8850
7
- data.tar.gz: 97e1cbee0e977bf161fcf82b161e961877afc795eaf2e61475b3d88c2a2f524d3ad4233a9be421d310659d146f94c738ae83c0b4d154f570688671499f22902a
6
+ metadata.gz: ff8a56b94c0834938085284cbc4290390554ebbe030a0a69dd19fc3efc0ff198c35686656e93e0b4b734c120ce88b4da4a7ef3180128b87bf0acff12c56eb2cb
7
+ data.tar.gz: 96eddc81041593518dd8139114e89b59634d658fc20bad5ebdde297a5517914b561182028fe18359da3df83b9c204f6aecd7262208c28fd1caba5423bdfa36cb
@@ -0,0 +1,115 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/transformer'
5
+ require 'rumale/validation'
6
+
7
+ module Rumale
8
+ module MetricLearning
9
+ # LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis.
10
+ #
11
+ # @example
12
+ # require 'rumale/metric_learning/local_fisher_discriminant_analysis'
13
+ #
14
+ # transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
15
+ # transformer.fit(training_samples, training_labels)
16
+ # low_samples = transformer.transform(testing_samples)
17
+ #
18
+ # *Reference*
19
+ # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
20
+ class LocalFisherDiscriminantAnalysis < ::Rumale::Base::Estimator
21
+ include ::Rumale::Base::Transformer
22
+
23
+ # Returns the transform matrix.
24
+ # @return [Numo::DFloat] (shape: [n_components, n_features])
25
+ attr_reader :components
26
+
27
+ # Returns the class labels.
28
+ # @return [Numo::Int32] (shape: [n_classes])
29
+ attr_reader :classes
30
+
31
+ # Create a new transformer with LocalFisherDiscriminantAnalysis.
32
+ #
33
+ # @param n_components [Integer] The number of components.
34
+ # @param gamma [Float] The parameter of the RBF kernel used to compute the affinity matrix. If nil, it is set to 1 / n_features.
35
+ def initialize(n_components: nil, gamma: nil)
36
+ super()
37
+ @params = {
38
+ n_components: n_components,
39
+ gamma: gamma
40
+ }
41
+ end
42
+
43
+ # Fit the model with given training data.
44
+ #
45
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
46
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
47
+ # @return [LocalFisherDiscriminantAnalysis] The learned transformer itself.
48
+ def fit(x, y)
49
+ unless enable_linalg?(warning: false)
50
+ raise 'LocalFisherDiscriminatAnalysis#fit requires Numo::Linalg but that is not loaded.'
51
+ end
52
+
53
+ x = Rumale::Validation.check_convert_sample_array(x)
54
+ y = Rumale::Validation.check_convert_label_array(y)
55
+ Rumale::Validation.check_sample_size(x, y)
56
+
57
+ # initialize some variables.
58
+ n_samples, n_features = x.shape
59
+ @classes = Numo::Int32[*y.to_a.uniq.sort]
60
+ n_components = @params[:n_components] || n_features
61
+ @params[:gamma] ||= 1.fdiv(n_features)
62
+ affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, @params[:gamma])
63
+ affinity_mat[affinity_mat.diag_indices] = 1.0
64
+
65
+ # calculate within and mixture scatter matrices.
66
+ class_mat = Numo::DFloat.zeros(n_samples, n_samples)
67
+ within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
68
+ @classes.each do |label|
69
+ pos = y.eq(label)
70
+ n_class_samples = pos.count
71
+ pos_vec = Numo::DFloat.cast(pos)
72
+ pos_mat = pos_vec.outer(pos_vec)
73
+ class_mat += pos_mat
74
+ within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
75
+ end
76
+
77
+ mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
78
+ within_weight_mat *= affinity_mat
79
+ mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
80
+ within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat
81
+
82
+ # calculate components.
83
+ mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
84
+ within_mat = x.transpose.dot(within_weight_mat.dot(x))
85
+ _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
86
+ comps = evecs.reverse(1).transpose.dup
87
+ @components = n_components == 1 ? comps[0, true].dup : comps.dup
88
+ self
89
+ end
90
+
91
+ # Fit the model with training data, and then transform them with the learned model.
92
+ #
93
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
94
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
95
+ # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
96
+ def fit_transform(x, y)
97
+ x = Rumale::Validation.check_convert_sample_array(x)
98
+ y = Rumale::Validation.check_convert_label_array(y)
99
+ Rumale::Validation.check_sample_size(x, y)
100
+
101
+ fit(x, y).transform(x)
102
+ end
103
+
104
+ # Transform the given data with the learned model.
105
+ #
106
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
107
+ # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
108
+ def transform(x)
109
+ x = Rumale::Validation.check_convert_sample_array(x)
110
+
111
+ x.dot(@components.transpose)
112
+ end
113
+ end
114
+ end
115
+ end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # Module for metric learning algorithms.
6
6
  module MetricLearning
7
7
  # @!visibility private
8
- VERSION = '0.25.0'
8
+ VERSION = '0.26.0'
9
9
  end
10
10
  end
@@ -3,6 +3,7 @@
3
3
  require 'numo/narray'
4
4
 
5
5
  require_relative 'metric_learning/fisher_discriminant_analysis'
6
+ require_relative 'metric_learning/local_fisher_discriminant_analysis'
6
7
  require_relative 'metric_learning/mlkr'
7
8
  require_relative 'metric_learning/neighbourhood_component_analysis'
8
9
  require_relative 'metric_learning/version'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-metric_learning
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.25.0
4
+ version: 0.26.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-01-18 00:00:00.000000000 Z
11
+ date: 2023-02-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lbfgsb
@@ -44,28 +44,28 @@ dependencies:
44
44
  requirements:
45
45
  - - "~>"
46
46
  - !ruby/object:Gem::Version
47
- version: 0.25.0
47
+ version: 0.26.0
48
48
  type: :runtime
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
52
  - - "~>"
53
53
  - !ruby/object:Gem::Version
54
- version: 0.25.0
54
+ version: 0.26.0
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: rumale-decomposition
57
57
  requirement: !ruby/object:Gem::Requirement
58
58
  requirements:
59
59
  - - "~>"
60
60
  - !ruby/object:Gem::Version
61
- version: 0.25.0
61
+ version: 0.26.0
62
62
  type: :runtime
63
63
  prerelease: false
64
64
  version_requirements: !ruby/object:Gem::Requirement
65
65
  requirements:
66
66
  - - "~>"
67
67
  - !ruby/object:Gem::Version
68
- version: 0.25.0
68
+ version: 0.26.0
69
69
  description: |
70
70
  Rumale::MetricLearning provides metric learning algorithms,
71
71
  such as Fisher Discriminant Analysis and Neighbourhood Component Analysis
@@ -80,6 +80,7 @@ files:
80
80
  - README.md
81
81
  - lib/rumale/metric_learning.rb
82
82
  - lib/rumale/metric_learning/fisher_discriminant_analysis.rb
83
+ - lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
83
84
  - lib/rumale/metric_learning/mlkr.rb
84
85
  - lib/rumale/metric_learning/neighbourhood_component_analysis.rb
85
86
  - lib/rumale/metric_learning/version.rb