rumale-metric_learning 0.25.0 → 0.26.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8fbc84de9f21a9fda63e31898cf1fa28b2a29dad7f41821f9cf68d64a881906d
4
- data.tar.gz: bf66f89e3d23653947e8bcdbd3c5054c88ecaca13b51d26c89cf3764f2fa6179
3
+ metadata.gz: 7f20a7e0bf72c2a9869a837c493121ce2ec70132d9818f5bfca3dd2b8dd7db94
4
+ data.tar.gz: af5c7339a4ef9164223c18b6e18c9453b081b90be787cc9a23ccb4fc9493566f
5
5
  SHA512:
6
- metadata.gz: 4d63cd938c957f3963d4497b7067fc77171f977a079d4ff94a2d82acf0c4a1123d8ac50833437c2c1bfd66d3decaeabaa772e2bf8630dd7ed7443788249d8850
7
- data.tar.gz: 97e1cbee0e977bf161fcf82b161e961877afc795eaf2e61475b3d88c2a2f524d3ad4233a9be421d310659d146f94c738ae83c0b4d154f570688671499f22902a
6
+ metadata.gz: ff8a56b94c0834938085284cbc4290390554ebbe030a0a69dd19fc3efc0ff198c35686656e93e0b4b734c120ce88b4da4a7ef3180128b87bf0acff12c56eb2cb
7
+ data.tar.gz: 96eddc81041593518dd8139114e89b59634d658fc20bad5ebdde297a5517914b561182028fe18359da3df83b9c204f6aecd7262208c28fd1caba5423bdfa36cb
@@ -0,0 +1,115 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/transformer'
5
+ require 'rumale/validation'
6
+
7
+ module Rumale
8
+ module MetricLearning
9
+ # LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis.
10
+ #
11
+ # @example
12
+ # require 'rumale/metric_learning/local_fisher_discriminant_analysis'
13
+ #
14
+ # transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
15
+ # transformer.fit(training_samples, training_labels)
16
+ # low_samples = transformer.transform(testing_samples)
17
+ #
18
+ # *Reference*
19
+ # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
20
+ class LocalFisherDiscriminantAnalysis < ::Rumale::Base::Estimator
21
+ include ::Rumale::Base::Transformer
22
+
23
+ # Returns the transform matrix.
24
+ # @return [Numo::DFloat] (shape: [n_components, n_features])
25
+ attr_reader :components
26
+
27
+ # Returns the class labels.
28
+ # @return [Numo::Int32] (shape: [n_classes])
29
+ attr_reader :classes
30
+
31
+ # Create a new transformer with LocalFisherDiscriminantAnalysis.
32
+ #
33
+ # @param n_components [Integer] The number of components.
34
+ # @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features.
35
+ def initialize(n_components: nil, gamma: nil)
36
+ super()
37
+ @params = {
38
+ n_components: n_components,
39
+ gamma: gamma
40
+ }
41
+ end
42
+
43
+ # Fit the model with given training data.
44
+ #
45
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
46
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
47
+ # @return [LocalFisherDiscriminantAnalysis] The learned transformer itself.
48
+ def fit(x, y)
49
+ unless enable_linalg?(warning: false)
50
+ raise 'LocalFisherDiscriminatAnalysis#fit requires Numo::Linalg but that is not loaded.'
51
+ end
52
+
53
+ x = Rumale::Validation.check_convert_sample_array(x)
54
+ y = Rumale::Validation.check_convert_label_array(y)
55
+ Rumale::Validation.check_sample_size(x, y)
56
+
57
+ # initialize some variables.
58
+ n_samples, n_features = x.shape
59
+ @classes = Numo::Int32[*y.to_a.uniq.sort]
60
+ n_components = @params[:n_components] || n_features
61
+ @params[:gamma] ||= 1.fdiv(n_features)
62
+ affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, @params[:gamma])
63
+ affinity_mat[affinity_mat.diag_indices] = 1.0
64
+
65
+ # calculate within and mixture scatter matrices.
66
+ class_mat = Numo::DFloat.zeros(n_samples, n_samples)
67
+ within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
68
+ @classes.each do |label|
69
+ pos = y.eq(label)
70
+ n_class_samples = pos.count
71
+ pos_vec = Numo::DFloat.cast(pos)
72
+ pos_mat = pos_vec.outer(pos_vec)
73
+ class_mat += pos_mat
74
+ within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
75
+ end
76
+
77
+ mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
78
+ within_weight_mat *= affinity_mat
79
+ mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
80
+ within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat
81
+
82
+ # calculate components.
83
+ mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
84
+ within_mat = x.transpose.dot(within_weight_mat.dot(x))
85
+ _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
86
+ comps = evecs.reverse(1).transpose.dup
87
+ @components = n_components == 1 ? comps[0, true].dup : comps.dup
88
+ self
89
+ end
90
+
91
+ # Fit the model with training data, and then transform them with the learned model.
92
+ #
93
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
94
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
95
+ # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
96
+ def fit_transform(x, y)
97
+ x = Rumale::Validation.check_convert_sample_array(x)
98
+ y = Rumale::Validation.check_convert_label_array(y)
99
+ Rumale::Validation.check_sample_size(x, y)
100
+
101
+ fit(x, y).transform(x)
102
+ end
103
+
104
+ # Transform the given data with the learned model.
105
+ #
106
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
107
+ # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
108
+ def transform(x)
109
+ x = Rumale::Validation.check_convert_sample_array(x)
110
+
111
+ x.dot(@components.transpose)
112
+ end
113
+ end
114
+ end
115
+ end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # Module for metric learning algorithms.
6
6
  module MetricLearning
7
7
  # @!visibility private
8
- VERSION = '0.25.0'
8
+ VERSION = '0.26.0'
9
9
  end
10
10
  end
@@ -3,6 +3,7 @@
3
3
  require 'numo/narray'
4
4
 
5
5
  require_relative 'metric_learning/fisher_discriminant_analysis'
6
+ require_relative 'metric_learning/local_fisher_discriminant_analysis'
6
7
  require_relative 'metric_learning/mlkr'
7
8
  require_relative 'metric_learning/neighbourhood_component_analysis'
8
9
  require_relative 'metric_learning/version'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-metric_learning
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.25.0
4
+ version: 0.26.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-01-18 00:00:00.000000000 Z
11
+ date: 2023-02-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: lbfgsb
@@ -44,28 +44,28 @@ dependencies:
44
44
  requirements:
45
45
  - - "~>"
46
46
  - !ruby/object:Gem::Version
47
- version: 0.25.0
47
+ version: 0.26.0
48
48
  type: :runtime
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
52
  - - "~>"
53
53
  - !ruby/object:Gem::Version
54
- version: 0.25.0
54
+ version: 0.26.0
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: rumale-decomposition
57
57
  requirement: !ruby/object:Gem::Requirement
58
58
  requirements:
59
59
  - - "~>"
60
60
  - !ruby/object:Gem::Version
61
- version: 0.25.0
61
+ version: 0.26.0
62
62
  type: :runtime
63
63
  prerelease: false
64
64
  version_requirements: !ruby/object:Gem::Requirement
65
65
  requirements:
66
66
  - - "~>"
67
67
  - !ruby/object:Gem::Version
68
- version: 0.25.0
68
+ version: 0.26.0
69
69
  description: |
70
70
  Rumale::MetricLearning provides metric learning algorithms,
71
71
  such as Fisher Discriminant Analysis and Neighbourhood Component Analysis
@@ -80,6 +80,7 @@ files:
80
80
  - README.md
81
81
  - lib/rumale/metric_learning.rb
82
82
  - lib/rumale/metric_learning/fisher_discriminant_analysis.rb
83
+ - lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
83
84
  - lib/rumale/metric_learning/mlkr.rb
84
85
  - lib/rumale/metric_learning/neighbourhood_component_analysis.rb
85
86
  - lib/rumale/metric_learning/version.rb