rumale-metric_learning 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +27 -0
- data/README.md +35 -0
- data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +118 -0
- data/lib/rumale/metric_learning/mlkr.rb +162 -0
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +167 -0
- data/lib/rumale/metric_learning/version.rb +10 -0
- data/lib/rumale/metric_learning.rb +8 -0
- metadata +114 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
+---
+SHA256:
+  metadata.gz: de8906421117ba7ceb261e21f9a469cdda66c4a582acddef8b7c25d9e81b5b2c
+  data.tar.gz: 1f89eca43ee6a34c4bfecb3779d6b7e2ba210833834e46105ab9d56cfbe3ef91
+SHA512:
+  metadata.gz: 6cd556eb2d5060f16f766fae6595b3cf302723f38788ab67680ad053306642a86ef8745e100d387059cc304ad1f07cc9e7e6493bddaa6bc9e066c2ebda5c0772
+  data.tar.gz: 77afbc98ad84566a932193ddb3c774b5edff63a7b7747958cf651ecbff41c1826e0571e9f6232e5261520ecf7545c9e3e8195680ca1f18aff5506233d0d5b682
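The checksums above are the SHA-256 and SHA-512 digests of the two archives packed inside the published `.gem` file. A minimal verification sketch, assuming the gem has already been fetched and unpacked (for example via `gem fetch rumale-metric_learning -v 0.24.0` and `tar -xf`) so that `metadata.gz` and `data.tar.gz` sit in the working directory; the file paths are illustrative, not part of the package:

```ruby
require 'digest'

# Expected SHA-256 digests copied from checksums.yaml above.
expected = {
  'metadata.gz' => 'de8906421117ba7ceb261e21f9a469cdda66c4a582acddef8b7c25d9e81b5b2c',
  'data.tar.gz' => '1f89eca43ee6a34c4bfecb3779d6b7e2ba210833834e46105ab9d56cfbe3ef91'
}

expected.each do |file, digest|
  # Digest::SHA256.file streams the file from disk and returns a digest object.
  actual = Digest::SHA256.file(file).hexdigest
  puts format('%-12s %s', file, actual == digest ? 'OK' : 'MISMATCH')
end
```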
data/LICENSE.txt
ADDED
@@ -0,0 +1,27 @@
+Copyright (c) 2022 Atsushi Tatsuma
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md
ADDED
@@ -0,0 +1,35 @@
+# Rumale::MetricLearning
+
+[](https://badge.fury.io/rb/rumale-metric_learning)
+[](https://github.com/yoshoku/rumale/blob/main/rumale-metric_learning/LICENSE.txt)
+[](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning.html)
+
+Rumale is a machine learning library in Ruby.
+Rumale::MetricLearning provides metric learning algorithms,
+such as Fisher Discriminant Analysis and Neighbourhood Component Analysis,
+with Rumale interface.
+
+
+## Installation
+
+Add this line to your application's Gemfile:
+
+```ruby
+gem 'rumale-metric_learning'
+```
+
+And then execute:
+
+    $ bundle install
+
+Or install it yourself as:
+
+    $ gem install rumale-metric_learning
+
+## Documentation
+
+- [Rumale API Documentation - MetricLearning](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning.html)
+
+## License
+
+The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
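The README stops at installation, so here is a minimal usage sketch of the Rumale interface described above. It assumes `rumale-metric_learning` is installed together with `numo-linalg` (FisherDiscriminantAnalysis#fit raises unless Numo::Linalg is loaded); the toy samples are made-up values, purely for illustration:

```ruby
require 'numo/narray'
require 'numo/linalg/autoloader' # Numo::Linalg backend needed by FisherDiscriminantAnalysis#fit
require 'rumale/metric_learning/fisher_discriminant_analysis'

# Two tiny, well-separated classes (hypothetical data).
x = Numo::DFloat[[1.0, 2.0], [1.2, 1.8], [5.0, 6.0], [5.2, 5.8]]
y = Numo::Int32[0, 0, 1, 1]

transformer = Rumale::MetricLearning::FisherDiscriminantAnalysis.new(n_components: 1)
z = transformer.fit_transform(x, y) # => Numo::DFloat, shape [4]: one discriminant score per sample
```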
data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb
ADDED
@@ -0,0 +1,118 @@
+# frozen_string_literal: true
+
+require 'rumale/base/estimator'
+require 'rumale/base/transformer'
+require 'rumale/validation'
+
+module Rumale
+  module MetricLearning
+    # FisherDiscriminantAnalysis is a class that implements Fisher Discriminant Analysis.
+    #
+    # @example
+    #   require 'rumale/metric_learning/fisher_discriminant_analysis'
+    #
+    #   transformer = Rumale::MetricLearning::FisherDiscriminantAnalysis.new
+    #   transformer.fit(training_samples, training_labels)
+    #   low_samples = transformer.transform(testing_samples)
+    #
+    # *Reference*
+    # - Fisher, R. A., "The use of multiple measurements in taxonomic problems," Annals of Eugenics, vol. 7, pp. 179--188, 1936.
+    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
+    class FisherDiscriminantAnalysis < ::Rumale::Base::Estimator
+      include ::Rumale::Base::Transformer
+
+      # Returns the transform matrix.
+      # @return [Numo::DFloat] (shape: [n_components, n_features])
+      attr_reader :components
+
+      # Returns the mean vector.
+      # @return [Numo::DFloat] (shape: [n_features])
+      attr_reader :mean
+
+      # Returns the class mean vectors.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :class_means
+
+      # Return the class labels.
+      # @return [Numo::Int32] (shape: [n_classes])
+      attr_reader :classes
+
+      # Create a new transformer with FisherDiscriminantAnalysis.
+      #
+      # @param n_components [Integer] The number of components.
+      #   If nil is given, the number of components will be set to [n_features, n_classes - 1].min
+      def initialize(n_components: nil)
+        super()
+        @params = { n_components: n_components }
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [FisherDiscriminantAnalysis] The learned classifier itself.
+      def fit(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_label_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+        unless enable_linalg?(warning: false)
+          raise 'FisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
+        end
+
+        # initialize some variables.
+        n_features = x.shape[1]
+        @classes = Numo::Int32[*y.to_a.uniq.sort]
+        n_classes = @classes.size
+        n_components = if @params[:n_components].nil?
+                         [n_features, n_classes - 1].min
+                       else
+                         [n_features, @params[:n_components]].min
+                       end
+
+        # calculate within and between scatter matrices.
+        within_mat = Numo::DFloat.zeros(n_features, n_features)
+        between_mat = Numo::DFloat.zeros(n_features, n_features)
+        @class_means = Numo::DFloat.zeros(n_classes, n_features)
+        @mean = x.mean(0)
+        @classes.each_with_index do |label, i|
+          mask_vec = y.eq(label)
+          sz_class = mask_vec.count
+          class_samples = x[mask_vec, true]
+          class_mean = class_samples.mean(0)
+          within_mat += (class_samples - class_mean).transpose.dot(class_samples - class_mean)
+          between_mat += sz_class * (class_mean - @mean).expand_dims(1) * (class_mean - @mean)
+          @class_means[i, true] = class_mean
+        end
+
+        # calculate components.
+        _, evecs = Numo::Linalg.eigh(between_mat, within_mat, vals_range: (n_features - n_components)...n_features)
+        comps = evecs.reverse(1).transpose.dup
+        @components = n_components == 1 ? comps[0, true].dup : comps.dup
+        self
+      end
+
+      # Fit the model with training data, and then transform them with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
+      def fit_transform(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_label_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+
+        fit(x, y).transform(x)
+      end
+
+      # Transform the given data with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
+      def transform(x)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+
+        x.dot(@components.transpose)
+      end
+    end
+  end
+end
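For readers following the `fit` method above: `within_mat` and `between_mat` are the usual within-class and between-class scatter matrices, and `Numo::Linalg.eigh(between_mat, within_mat, ...)` solves the corresponding generalized eigenvalue problem. In standard notation (a textbook restatement of what the code implements, not text from the gem):

$$
S_W = \sum_{c} \sum_{i \in c} (\mathbf{x}_i - \mathbf{m}_c)(\mathbf{x}_i - \mathbf{m}_c)^{\top}, \qquad
S_B = \sum_{c} n_c (\mathbf{m}_c - \mathbf{m})(\mathbf{m}_c - \mathbf{m})^{\top}, \qquad
S_B \mathbf{w} = \lambda S_W \mathbf{w},
$$

and the eigenvectors with the largest eigenvalues (at most `n_classes - 1` of them carry discriminative information) become the rows of `@components`.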
data/lib/rumale/metric_learning/mlkr.rb
ADDED
@@ -0,0 +1,162 @@
+# frozen_string_literal: true
+
+require 'lbfgsb'
+
+require 'rumale/base/estimator'
+require 'rumale/base/transformer'
+require 'rumale/decomposition/pca'
+require 'rumale/pairwise_metric'
+require 'rumale/utils'
+require 'rumale/validation'
+
+module Rumale
+  module MetricLearning
+    # MLKR is a class that implements Metric Learning for Kernel Regression.
+    #
+    # @example
+    #   require 'rumale/metric_learning/mlkr'
+    #
+    #   transformer = Rumale::MetricLearning::MLKR.new
+    #   transformer.fit(training_samples, training_target_values)
+    #   low_samples = transformer.transform(testing_samples)
+    #
+    # *Reference*
+    # - Weinberger, K. Q. and Tesauro, G., "Metric Learning for Kernel Regression," Proc. AISTATS'07, pp. 612--629, 2007.
+    class MLKR < ::Rumale::Base::Estimator
+      include ::Rumale::Base::Transformer
+
+      # Returns the metric components.
+      # @return [Numo::DFloat] (shape: [n_components, n_features])
+      attr_reader :components
+
+      # Return the number of iterations run for optimization
+      # @return [Integer]
+      attr_reader :n_iter
+
+      # Return the random generator.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new transformer with MLKR.
+      #
+      # @param n_components [Integer] The number of components.
+      # @param init [String] The initialization method for components ('random' or 'pca').
+      # @param max_iter [Integer] The maximum number of iterations.
+      # @param tol [Float] The tolerance of termination criterion.
+      #   This value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
+      # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
+      #   If true is given, 'iterate.dat' file is generated by lbfgsb.rb.
+      # @param random_seed [Integer] The seed value using to initialize the random generator.
+      def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
+        super()
+        @params = {
+          n_components: n_components,
+          init: init,
+          max_iter: max_iter,
+          tol: tol,
+          verbose: verbose,
+          random_seed: random_seed || srand
+        }
+        @rng = Random.new(@params[:random_seed])
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
+      # @return [MLKR] The learned classifier itself.
+      def fit(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_target_value_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+
+        n_features = x.shape[1]
+        n_components = if @params[:n_components].nil?
+                         n_features
+                       else
+                         [n_features, @params[:n_components]].min
+                       end
+        @components, @n_iter = optimize_components(x, y, n_features, n_components)
+        @prototypes = x.dot(@components.transpose)
+        @values = y
+        self
+      end
+
+      # Fit the model with training data, and then transform them with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
+      def fit_transform(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_target_value_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+
+        fit(x, y).transform(x)
+      end
+
+      # Transform the given data with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
+      def transform(x)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+
+        x.dot(@components.transpose)
+      end
+
+      private
+
+      def init_components(x, n_features, n_components)
+        if @params[:init] == 'pca'
+          pca = ::Rumale::Decomposition::PCA.new(n_components: n_components)
+          pca.fit(x).components.flatten.dup
+        else
+          ::Rumale::Utils.rand_normal([n_features, n_components], @rng.dup).flatten.dup
+        end
+      end
+
+      def optimize_components(x, y, n_features, n_components)
+        # initialize components.
+        comp_init = init_components(x, n_features, n_components)
+        # initialize optimization results.
+        res = {}
+        res[:x] = comp_init
+        res[:n_iter] = 0
+        # perform optimization.
+        verbose = @params[:verbose] ? 1 : -1
+        res = Lbfgsb.minimize(
+          fnc: method(:mlkr_fnc), jcb: true, x_init: comp_init, args: [x, y],
+          maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
+        )
+        # return the results.
+        n_iter = res[:n_iter]
+        comps = n_components == 1 ? res[:x].dup : res[:x].reshape(n_components, n_features)
+        [comps, n_iter]
+      end
+
+      def mlkr_fnc(w, x, y)
+        # initialize some variables.
+        n_features = x.shape[1]
+        n_components = w.size / n_features
+        # projection.
+        w = w.reshape(n_components, n_features)
+        z = x.dot(w.transpose)
+        # predict values.
+        kernel_mat = Numo::NMath.exp(-::Rumale::PairwiseMetric.squared_error(z))
+        kernel_mat[kernel_mat.diag_indices] = 0.0
+        norm = kernel_mat.sum(axis: 1)
+        norm[norm.eq(0)] = 1
+        y_pred = kernel_mat.dot(y) / norm
+        # calculate loss.
+        y_diff = y_pred - y
+        loss = (y_diff**2).sum
+        # calculate gradient.
+        weight_mat = y_diff * y_diff.expand_dims(1) * kernel_mat
+        weight_mat = weight_mat.sum(axis: 0).diag - weight_mat
+        gradient = 8 * z.transpose.dot(weight_mat).dot(x)
+        [loss, gradient.flatten.dup]
+      end
+    end
+  end
+end
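MLKR learns a linear projection that minimizes the leave-one-out error of a Gaussian-kernel regressor in the projected space (the loss computed in `mlkr_fnc` above). A minimal usage sketch on made-up regression data, assuming `rumale-metric_learning` and its `rumale-decomposition` dependency are installed; the data and hyperparameters are illustrative only:

```ruby
require 'numo/narray'
require 'rumale/metric_learning/mlkr'

# Hypothetical toy regression task: 30 samples, 3 features, 1-D target.
x = Numo::DFloat.new(30, 3).rand
y = 2.0 * x[true, 0] + x[true, 1] - 0.5 * x[true, 2]

transformer = Rumale::MetricLearning::MLKR.new(n_components: 2, init: 'pca', random_seed: 1)
z = transformer.fit_transform(x, y) # => Numo::DFloat, shape [30, 2]
puts transformer.n_iter             # number of L-BFGS-B iterations actually run
```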
data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb
ADDED
@@ -0,0 +1,167 @@
+# frozen_string_literal: true
+
+require 'lbfgsb'
+
+require 'rumale/base/estimator'
+require 'rumale/base/transformer'
+require 'rumale/utils'
+require 'rumale/validation'
+require 'rumale/pairwise_metric'
+
+module Rumale
+  module MetricLearning
+    # NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
+    #
+    # @example
+    #   require 'rumale/metric_learning/neighbourhood_component_analysis'
+    #
+    #   transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
+    #   transformer.fit(training_samples, training_labels)
+    #   low_samples = transformer.transform(testing_samples)
+    #
+    # *Reference*
+    # - Goldberger, J., Roweis, S., Hinton, G., and Salakhutdinov, R., "Neighbourhood Component Analysis," Advances in NIPS'17, pp. 513--520, 2005.
+    class NeighbourhoodComponentAnalysis < ::Rumale::Base::Estimator
+      include ::Rumale::Base::Transformer
+
+      # Returns the neighbourhood components.
+      # @return [Numo::DFloat] (shape: [n_components, n_features])
+      attr_reader :components
+
+      # Return the number of iterations run for optimization
+      # @return [Integer]
+      attr_reader :n_iter
+
+      # Return the random generator.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new transformer with NeighbourhoodComponentAnalysis.
+      #
+      # @param n_components [Integer] The number of components.
+      # @param init [String] The initialization method for components ('random' or 'pca').
+      # @param max_iter [Integer] The maximum number of iterations.
+      # @param tol [Float] The tolerance of termination criterion.
+      #   This value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
+      # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
+      #   If true is given, 'iterate.dat' file is generated by lbfgsb.rb.
+      # @param random_seed [Integer] The seed value using to initialize the random generator.
+      def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
+        super()
+        @params = {
+          n_components: n_components,
+          init: init,
+          max_iter: max_iter,
+          tol: tol,
+          verbose: verbose,
+          random_seed: random_seed || srand
+        }
+        @rng = Random.new(@params[:random_seed])
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [NeighbourhoodComponentAnalysis] The learned classifier itself.
+      def fit(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_label_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+
+        n_features = x.shape[1]
+        n_components = if @params[:n_components].nil?
+                         n_features
+                       else
+                         [n_features, @params[:n_components]].min
+                       end
+        @components, @n_iter = optimize_components(x, y, n_features, n_components)
+        self
+      end
+
+      # Fit the model with training data, and then transform them with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
+      def fit_transform(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_label_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+
+        fit(x, y).transform(x)
+      end
+
+      # Transform the given data with the learned model.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
+      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
+      def transform(x)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+
+        x.dot(@components.transpose)
+      end
+
+      private
+
+      def init_components(x, n_features, n_components)
+        if @params[:init] == 'pca'
+          pca = ::Rumale::Decomposition::PCA.new(n_components: n_components)
+          pca.fit(x).components.flatten.dup
+        else
+          ::Rumale::Utils.rand_normal([n_features, n_components], @rng.dup).flatten.dup
+        end
+      end
+
+      def optimize_components(x, y, n_features, n_components)
+        # initialize components.
+        comp_init = init_components(x, n_features, n_components)
+        # initialize optimization results.
+        res = {}
+        res[:x] = comp_init
+        res[:n_iter] = 0
+        # perform optimization.
+        verbose = @params[:verbose] ? 1 : -1
+        res = Lbfgsb.minimize(
+          fnc: method(:nca_fnc), jcb: true, x_init: comp_init, args: [x, y],
+          maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
+        )
+        # return the results.
+        n_iter = res[:n_iter]
+        comps = n_components == 1 ? res[:x].dup : res[:x].reshape(n_components, n_features)
+        [comps, n_iter]
+      end
+
+      def nca_fnc(w, x, y)
+        # initialize some variables.
+        n_samples, n_features = x.shape
+        n_components = w.size / n_features
+        # projection.
+        w = w.reshape(n_components, n_features)
+        z = x.dot(w.transpose)
+        # calculate probability matrix.
+        prob_mat = probability_matrix(z)
+        # calculate loss and gradient.
+        # NOTE:
+        # NCA attempts to maximize its objective function.
+        # For the minimization algorithm, the objective function value is subtracted from the maximum value (n_samples).
+        mask_mat = y.expand_dims(1).eq(y)
+        masked_prob_mat = prob_mat * mask_mat
+        loss = n_samples - masked_prob_mat.sum
+        sum_probs = masked_prob_mat.sum(axis: 1)
+        weight_mat = (sum_probs.expand_dims(1) * prob_mat - masked_prob_mat)
+        weight_mat += weight_mat.transpose
+        weight_mat = weight_mat.sum(axis: 0).diag - weight_mat
+        gradient = -2 * z.transpose.dot(weight_mat).dot(x)
+        [loss, gradient.flatten.dup]
+      end
+
+      def probability_matrix(z)
+        prob_mat = Numo::NMath.exp(-::Rumale::PairwiseMetric.squared_error(z))
+        prob_mat[prob_mat.diag_indices] = 0.0
+        prob_mat /= prob_mat.sum(axis: 1).expand_dims(1)
+        prob_mat
+      end
+    end
+  end
+end
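As the NOTE in `nca_fnc` above says, NCA maximizes the expected number of correctly classified points under a stochastic nearest-neighbour rule, and the code turns that into a minimization by subtracting the objective from `n_samples`. In the notation of the Goldberger et al. reference (a restatement, not text from the gem), with projected samples $\mathbf{z}_i = W\mathbf{x}_i$:

$$
p_{ij} = \frac{\exp(-\lVert \mathbf{z}_i - \mathbf{z}_j \rVert^2)}{\sum_{k \ne i} \exp(-\lVert \mathbf{z}_i - \mathbf{z}_k \rVert^2)}, \quad p_{ii} = 0, \qquad
\mathrm{loss}(W) = N - \sum_{i} \sum_{j:\, y_j = y_i} p_{ij},
$$

which is exactly what `probability_matrix` and the masked sum in `nca_fnc` compute.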
data/lib/rumale/metric_learning.rb
ADDED
@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+require 'numo/narray'
+
+require_relative 'metric_learning/fisher_discriminant_analysis'
+require_relative 'metric_learning/mlkr'
+require_relative 'metric_learning/neighbourhood_component_analysis'
+require_relative 'metric_learning/version'
metadata
ADDED
@@ -0,0 +1,114 @@
+--- !ruby/object:Gem::Specification
+name: rumale-metric_learning
+version: !ruby/object:Gem::Version
+  version: 0.24.0
+platform: ruby
+authors:
+- yoshoku
+autorequire:
+bindir: exe
+cert_chain: []
+date: 2022-12-31 00:00:00.000000000 Z
+dependencies:
+- !ruby/object:Gem::Dependency
+  name: lbfgsb
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: 0.3.0
+  type: :runtime
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: 0.3.0
+- !ruby/object:Gem::Dependency
+  name: numo-narray
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: 0.9.1
+  type: :runtime
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: 0.9.1
+- !ruby/object:Gem::Dependency
+  name: rumale-core
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.24.0
+  type: :runtime
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.24.0
+- !ruby/object:Gem::Dependency
+  name: rumale-decomposition
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.24.0
+  type: :runtime
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.24.0
+description: |
+  Rumale::MetricLearning provides metric learning algorithms,
+  such as Fisher Discriminant Analysis and Neighbourhood Component Analysis
+  with Rumale interface.
+email:
+- yoshoku@outlook.com
+executables: []
+extensions: []
+extra_rdoc_files: []
+files:
+- LICENSE.txt
+- README.md
+- lib/rumale/metric_learning.rb
+- lib/rumale/metric_learning/fisher_discriminant_analysis.rb
+- lib/rumale/metric_learning/mlkr.rb
+- lib/rumale/metric_learning/neighbourhood_component_analysis.rb
+- lib/rumale/metric_learning/version.rb
+homepage: https://github.com/yoshoku/rumale
+licenses:
+- BSD-3-Clause
+metadata:
+  homepage_uri: https://github.com/yoshoku/rumale
+  source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-metric_learning
+  changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
+  documentation_uri: https://yoshoku.github.io/rumale/doc/
+  rubygems_mfa_required: 'true'
+post_install_message:
+rdoc_options: []
+require_paths:
+- lib
+required_ruby_version: !ruby/object:Gem::Requirement
+  requirements:
+  - - ">="
+    - !ruby/object:Gem::Version
+      version: '0'
+required_rubygems_version: !ruby/object:Gem::Requirement
+  requirements:
+  - - ">="
+    - !ruby/object:Gem::Version
+      version: '0'
+requirements: []
+rubygems_version: 3.3.26
+signing_key:
+specification_version: 4
+summary: Rumale::MetricLearning provides metric learning algorithms with Rumale interface.
+test_files: []