rumale-metric_learning 0.25.0 → 0.27.0
Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: e7a10afa083e39f86746029f24ca6edd4edc417a045c8608d0c597b788197cf6
|
4
|
+
data.tar.gz: b3b757f6598fbb21d5f6b1383d1bdfa676962ac12136a02c95d5c7dbdb695065
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: a534929e8338c1513872c432576d2bf4a73fe9ff1999cfa997289626e479ab823dcaeff8217f0aeebee0392f64cc6f9ff1790f52c7047c994d55a7d01b55ae16
|
7
|
+
data.tar.gz: a346e8296fd2c1d28ab014d4869b1fb2876dae2b0874729a1768f13896c531c309fab05f8f0a0780d22f5b738e5b1494736cfb126b81ac395638663e73f78944
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
require 'rumale/base/transformer'
require 'rumale/pairwise_metric'
require 'rumale/validation'
|
6
|
+
|
7
|
+
module Rumale
  module MetricLearning
    # LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis.
    #
    # @example
    #   require 'rumale/metric_learning/local_fisher_discriminant_analysis'
    #
    #   transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
    #   transformer.fit(training_samples, training_labels)
    #   low_samples = transformer.transform(testing_samples)
    #
    # *Reference*
    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
    class LocalFisherDiscriminantAnalysis < ::Rumale::Base::Estimator
      include ::Rumale::Base::Transformer

      # Returns the transform matrix.
      # @return [Numo::DFloat] (shape: [n_components, n_features])
      attr_reader :components

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new transformer with LocalFisherDiscriminantAnalysis.
      #
      # @param n_components [Integer] The number of components. If nil, it is n_features.
      #   At fit time the value is clamped into the range [1, n_features].
      # @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features.
      def initialize(n_components: nil, gamma: nil)
        super()
        @params = {
          n_components: n_components,
          gamma: gamma
        }
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @return [LocalFisherDiscriminantAnalysis] The learned transformer itself.
      def fit(x, y)
        unless enable_linalg?(warning: false)
          raise 'LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
        end

        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        # initialize some variables.
        n_samples, n_features = x.shape
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        # Clamp so that the eigenvalue index range passed to eigh below is always valid and non-empty.
        n_components = (@params[:n_components] || n_features).clamp(1, n_features)
        # Resolve the default gamma locally instead of writing it back into @params, so that
        # refitting on data with a different number of features does not reuse a stale value.
        gamma = @params[:gamma] || 1.fdiv(n_features)
        affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, gamma)
        affinity_mat[affinity_mat.diag_indices] = 1.0

        # calculate within and mixture scatter matrices.
        # class_mat[i, j] is 1 when samples i and j share a label, 0 otherwise;
        # within_weight_mat[i, j] accumulates 1 / n_c for same-class pairs of class c.
        class_mat = Numo::DFloat.zeros(n_samples, n_samples)
        within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
        @classes.each do |label|
          pos = y.eq(label)
          n_class_samples = pos.count
          pos_vec = Numo::DFloat.cast(pos)
          pos_mat = pos_vec.outer(pos_vec)
          class_mat += pos_mat
          within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
        end

        # Mixture weights: A_ij / n for same-class pairs, 1 / n for different-class pairs.
        mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
        # Within weights: A_ij / n_c for same-class pairs, 0 otherwise.
        within_weight_mat *= affinity_mat
        # Convert each weight matrix W into its graph Laplacian D - W, where D = diag(row sums of W).
        mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
        within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat

        # calculate components.
        mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
        within_mat = x.transpose.dot(within_weight_mat.dot(x))
        # Generalized symmetric eigenproblem; vals_range selects the n_components largest
        # eigenvalues (eigh returns them in ascending order), and reverse(1) puts them in descending order.
        _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
        comps = evecs.reverse(1).transpose.dup
        @components = n_components == 1 ? comps[0, true].dup : comps
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      # The model must have been fitted beforehand so that components is set.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      def transform(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        x.dot(@components.transpose)
      end
    end
  end
end
|
@@ -3,6 +3,7 @@
|
|
3
3
|
require 'numo/narray'
|
4
4
|
|
5
5
|
require_relative 'metric_learning/fisher_discriminant_analysis'
|
6
|
+
require_relative 'metric_learning/local_fisher_discriminant_analysis'
|
6
7
|
require_relative 'metric_learning/mlkr'
|
7
8
|
require_relative 'metric_learning/neighbourhood_component_analysis'
|
8
9
|
require_relative 'metric_learning/version'
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-metric_learning
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.25.0
|
4
|
+
version: 0.27.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: lbfgsb
|
@@ -44,28 +44,28 @@ dependencies:
|
|
44
44
|
requirements:
|
45
45
|
- - "~>"
|
46
46
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.
|
47
|
+
version: 0.27.0
|
48
48
|
type: :runtime
|
49
49
|
prerelease: false
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
51
51
|
requirements:
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.
|
54
|
+
version: 0.27.0
|
55
55
|
- !ruby/object:Gem::Dependency
|
56
56
|
name: rumale-decomposition
|
57
57
|
requirement: !ruby/object:Gem::Requirement
|
58
58
|
requirements:
|
59
59
|
- - "~>"
|
60
60
|
- !ruby/object:Gem::Version
|
61
|
-
version: 0.
|
61
|
+
version: 0.27.0
|
62
62
|
type: :runtime
|
63
63
|
prerelease: false
|
64
64
|
version_requirements: !ruby/object:Gem::Requirement
|
65
65
|
requirements:
|
66
66
|
- - "~>"
|
67
67
|
- !ruby/object:Gem::Version
|
68
|
-
version: 0.
|
68
|
+
version: 0.27.0
|
69
69
|
description: |
|
70
70
|
Rumale::MetricLearning provides metric learning algorithms,
|
71
71
|
such as Fisher Discriminant Analysis and Neighbourhood Component Analysis
|
@@ -80,6 +80,7 @@ files:
|
|
80
80
|
- README.md
|
81
81
|
- lib/rumale/metric_learning.rb
|
82
82
|
- lib/rumale/metric_learning/fisher_discriminant_analysis.rb
|
83
|
+
- lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
|
83
84
|
- lib/rumale/metric_learning/mlkr.rb
|
84
85
|
- lib/rumale/metric_learning/neighbourhood_component_analysis.rb
|
85
86
|
- lib/rumale/metric_learning/version.rb
|