rumale-metric_learning 0.24.0 → 0.26.0
Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7f20a7e0bf72c2a9869a837c493121ce2ec70132d9818f5bfca3dd2b8dd7db94
|
4
|
+
data.tar.gz: af5c7339a4ef9164223c18b6e18c9453b081b90be787cc9a23ccb4fc9493566f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ff8a56b94c0834938085284cbc4290390554ebbe030a0a69dd19fc3efc0ff198c35686656e93e0b4b734c120ce88b4da4a7ef3180128b87bf0acff12c56eb2cb
|
7
|
+
data.tar.gz: 96eddc81041593518dd8139114e89b59634d658fc20bad5ebdde297a5517914b561182028fe18359da3df83b9c204f6aecd7262208c28fd1caba5423bdfa36cb
|
lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
CHANGED
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
require 'rumale/base/transformer'
require 'rumale/pairwise_metric'
require 'rumale/validation'
|
6
|
+
|
7
|
+
module Rumale
  module MetricLearning
    # LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis.
    #
    # @example
    #   require 'rumale/metric_learning/local_fisher_discriminant_analysis'
    #
    #   transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
    #   transformer.fit(training_samples, training_labels)
    #   low_samples = transformer.transform(testing_samples)
    #
    # *Reference*
    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
    class LocalFisherDiscriminantAnalysis < ::Rumale::Base::Estimator
      include ::Rumale::Base::Transformer

      # Returns the transform matrix.
      # When n_components is 1, this is a vector of shape [n_features].
      # @return [Numo::DFloat] (shape: [n_components, n_features])
      attr_reader :components

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new transformer with LocalFisherDiscriminantAnalysis.
      #
      # @param n_components [Integer] The number of components, if nil it is n_features.
      # @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features.
      def initialize(n_components: nil, gamma: nil)
        super()
        @params = {
          n_components: n_components,
          gamma: gamma
        }
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @raise [ArgumentError] If n_components is not within 1..n_features.
      # @return [LocalFisherDiscriminantAnalysis] The learned transformer itself.
      def fit(x, y)
        unless enable_linalg?(warning: false)
          raise 'LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
        end

        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        # initialize some variables.
        n_samples, n_features = x.shape
        n_components = @params[:n_components] || n_features
        unless (1..n_features).cover?(n_components)
          raise ArgumentError, "n_components must be between 1 and #{n_features}, but got #{n_components}."
        end

        @classes = Numo::Int32[*y.to_a.uniq.sort]
        # Resolve the default locally instead of writing it back into @params, so that re-fitting
        # on data with a different number of features does not reuse a stale default gamma.
        gamma = @params[:gamma] || 1.fdiv(n_features)
        affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, gamma)
        # Self-affinity is exactly 1; set it explicitly (presumably to discard round-off from the kernel computation).
        affinity_mat[affinity_mat.diag_indices] = 1.0

        # calculate within and mixture scatter matrices.
        # class_mat[i, j] is 1 when samples i and j share a label and 0 otherwise;
        # within_weight_mat[i, j] is 1 / n_l for such pairs, where n_l is the size of their class.
        class_mat = Numo::DFloat.zeros(n_samples, n_samples)
        within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
        @classes.each do |label|
          pos = y.eq(label)
          n_class_samples = pos.count
          pos_vec = Numo::DFloat.cast(pos)
          pos_mat = pos_vec.outer(pos_vec)
          class_mat += pos_mat
          within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
        end

        # Local mixture weights: A_ij / n for same-class pairs, 1 / n for different-class pairs.
        mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
        # Local within-class weights: A_ij / n_l for same-class pairs, 0 otherwise.
        within_weight_mat *= affinity_mat
        # Graph Laplacians L = D - W, where D is the diagonal matrix of the row sums of W.
        mixture_laplacian = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
        within_laplacian = within_weight_mat.sum(axis: 1).diag - within_weight_mat

        # calculate components.
        mixture_mat = x.transpose.dot(mixture_laplacian.dot(x))
        within_mat = x.transpose.dot(within_laplacian.dot(x))
        # Generalized symmetric eigenproblem; keep the eigenvectors of the n_components largest
        # eigenvalues. Eigenvalues come back in ascending order, so reverse the columns.
        _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
        comps = evecs.reverse(1).transpose
        # dup materializes the reversed/transposed view into a contiguous array.
        @components = n_components == 1 ? comps[0, true].dup : comps.dup
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      #   (shape: [n_samples] when n_components is 1).
      def fit_transform(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      #   (shape: [n_samples] when n_components is 1).
      def transform(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        x.dot(@components.transpose)
      end
    end
  end
end
|
@@ -3,6 +3,7 @@
|
|
3
3
|
require 'numo/narray'
|
4
4
|
|
5
5
|
require_relative 'metric_learning/fisher_discriminant_analysis'
|
6
|
+
require_relative 'metric_learning/local_fisher_discriminant_analysis'
|
6
7
|
require_relative 'metric_learning/mlkr'
|
7
8
|
require_relative 'metric_learning/neighbourhood_component_analysis'
|
8
9
|
require_relative 'metric_learning/version'
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-metric_learning
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.26.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-02-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: lbfgsb
|
@@ -44,28 +44,28 @@ dependencies:
|
|
44
44
|
requirements:
|
45
45
|
- - "~>"
|
46
46
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.
|
47
|
+
version: 0.26.0
|
48
48
|
type: :runtime
|
49
49
|
prerelease: false
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
51
51
|
requirements:
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.
|
54
|
+
version: 0.26.0
|
55
55
|
- !ruby/object:Gem::Dependency
|
56
56
|
name: rumale-decomposition
|
57
57
|
requirement: !ruby/object:Gem::Requirement
|
58
58
|
requirements:
|
59
59
|
- - "~>"
|
60
60
|
- !ruby/object:Gem::Version
|
61
|
-
version: 0.
|
61
|
+
version: 0.26.0
|
62
62
|
type: :runtime
|
63
63
|
prerelease: false
|
64
64
|
version_requirements: !ruby/object:Gem::Requirement
|
65
65
|
requirements:
|
66
66
|
- - "~>"
|
67
67
|
- !ruby/object:Gem::Version
|
68
|
-
version: 0.
|
68
|
+
version: 0.26.0
|
69
69
|
description: |
|
70
70
|
Rumale::MetricLearning provides metric learning algorithms,
|
71
71
|
such as Fisher Discriminant Analysis and Neighboourhood Component Analysis
|
@@ -80,6 +80,7 @@ files:
|
|
80
80
|
- README.md
|
81
81
|
- lib/rumale/metric_learning.rb
|
82
82
|
- lib/rumale/metric_learning/fisher_discriminant_analysis.rb
|
83
|
+
- lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
|
83
84
|
- lib/rumale/metric_learning/mlkr.rb
|
84
85
|
- lib/rumale/metric_learning/neighbourhood_component_analysis.rb
|
85
86
|
- lib/rumale/metric_learning/version.rb
|