rumale-metric_learning 0.25.0 → 0.26.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7f20a7e0bf72c2a9869a837c493121ce2ec70132d9818f5bfca3dd2b8dd7db94
|
4
|
+
data.tar.gz: af5c7339a4ef9164223c18b6e18c9453b081b90be787cc9a23ccb4fc9493566f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ff8a56b94c0834938085284cbc4290390554ebbe030a0a69dd19fc3efc0ff198c35686656e93e0b4b734c120ce88b4da4a7ef3180128b87bf0acff12c56eb2cb
|
7
|
+
data.tar.gz: 96eddc81041593518dd8139114e89b59634d658fc20bad5ebdde297a5517914b561182028fe18359da3df83b9c204f6aecd7262208c28fd1caba5423bdfa36cb
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
require 'rumale/base/transformer'
require 'rumale/pairwise_metric'
require 'rumale/validation'
|
6
|
+
|
7
|
+
module Rumale
  module MetricLearning
    # LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis.
    #
    # @example
    #   require 'rumale/metric_learning/local_fisher_discriminant_analysis'
    #
    #   transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
    #   transformer.fit(training_samples, training_labels)
    #   low_samples = transformer.transform(testing_samples)
    #
    # *Reference*
    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
    class LocalFisherDiscriminantAnalysis < ::Rumale::Base::Estimator
      include ::Rumale::Base::Transformer

      # Returns the transform matrix.
      # When n_components is 1, this is a 1-D array of shape [n_features].
      # @return [Numo::DFloat] (shape: [n_components, n_features])
      attr_reader :components

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new transformer with LocalFisherDiscriminantAnalysis.
      #
      # @param n_components [Integer] The number of components. If nil, it is set to the number of features.
      # @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features.
      def initialize(n_components: nil, gamma: nil)
        super()
        @params = {
          n_components: n_components,
          gamma: gamma
        }
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @return [LocalFisherDiscriminantAnalysis] The learned transformer itself.
      def fit(x, y)
        unless enable_linalg?(warning: false)
          raise 'LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
        end

        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        # initialize some variables.
        n_samples, n_features = x.shape
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_components = @params[:n_components] || n_features
        # The default gamma is written back into @params, so later calls to fit reuse it.
        @params[:gamma] ||= 1.fdiv(n_features)
        # Pairwise RBF affinity between all training samples; the diagonal is forced to 1.
        affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, @params[:gamma])
        affinity_mat[affinity_mat.diag_indices] = 1.0

        # calculate within and mixture scatter matrices.
        # class_mat[i, j] is 1 when samples i and j share a label, 0 otherwise.
        # within_weight_mat[i, j] is 1 / n_l when both samples belong to class l, 0 otherwise.
        class_mat = Numo::DFloat.zeros(n_samples, n_samples)
        within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
        @classes.each do |label|
          pos = y.eq(label)
          n_class_samples = pos.count
          pos_vec = Numo::DFloat.cast(pos)
          pos_mat = pos_vec.outer(pos_vec)
          class_mat += pos_mat
          within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
        end

        # Mixture (within + between) weights: affinity / n_samples for same-class pairs,
        # 1 / n_samples for pairs from different classes.
        mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
        within_weight_mat *= affinity_mat
        # Turn each weight matrix W into D - W, where D is the diagonal matrix of W's row sums.
        mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
        within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat

        # calculate components.
        mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
        within_mat = x.transpose.dot(within_weight_mat.dot(x))
        # Solve the generalized eigenproblem mixture_mat * v = lambda * within_mat * v.
        # Eigenvalues come back in ascending order, so this range keeps the n_components
        # largest ones, and reverse(1) puts the eigenvector of the largest first.
        _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
        comps = evecs.reverse(1).transpose
        @components = n_components == 1 ? comps[0, true].dup : comps.dup
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      #   When n_components is 1, the result is a 1-D array of shape [n_samples].
      def transform(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        x.dot(@components.transpose)
      end
    end
  end
end
|
@@ -3,6 +3,7 @@
|
|
3
3
|
require 'numo/narray'
|
4
4
|
|
5
5
|
require_relative 'metric_learning/fisher_discriminant_analysis'
|
6
|
+
require_relative 'metric_learning/local_fisher_discriminant_analysis'
|
6
7
|
require_relative 'metric_learning/mlkr'
|
7
8
|
require_relative 'metric_learning/neighbourhood_component_analysis'
|
8
9
|
require_relative 'metric_learning/version'
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-metric_learning
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.25.0
|
4
|
+
version: 0.26.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-02-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: lbfgsb
|
@@ -44,28 +44,28 @@ dependencies:
|
|
44
44
|
requirements:
|
45
45
|
- - "~>"
|
46
46
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.
|
47
|
+
version: 0.26.0
|
48
48
|
type: :runtime
|
49
49
|
prerelease: false
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
51
51
|
requirements:
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.
|
54
|
+
version: 0.26.0
|
55
55
|
- !ruby/object:Gem::Dependency
|
56
56
|
name: rumale-decomposition
|
57
57
|
requirement: !ruby/object:Gem::Requirement
|
58
58
|
requirements:
|
59
59
|
- - "~>"
|
60
60
|
- !ruby/object:Gem::Version
|
61
|
-
version: 0.
|
61
|
+
version: 0.26.0
|
62
62
|
type: :runtime
|
63
63
|
prerelease: false
|
64
64
|
version_requirements: !ruby/object:Gem::Requirement
|
65
65
|
requirements:
|
66
66
|
- - "~>"
|
67
67
|
- !ruby/object:Gem::Version
|
68
|
-
version: 0.
|
68
|
+
version: 0.26.0
|
69
69
|
description: |
|
70
70
|
Rumale::MetricLearning provides metric learning algorithms,
|
71
71
|
such as Fisher Discriminant Analysis and Neighbourhood Component Analysis
|
@@ -80,6 +80,7 @@ files:
|
|
80
80
|
- README.md
|
81
81
|
- lib/rumale/metric_learning.rb
|
82
82
|
- lib/rumale/metric_learning/fisher_discriminant_analysis.rb
|
83
|
+
- lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
|
83
84
|
- lib/rumale/metric_learning/mlkr.rb
|
84
85
|
- lib/rumale/metric_learning/neighbourhood_component_analysis.rb
|
85
86
|
- lib/rumale/metric_learning/version.rb
|