rumale-svm 0.8.0 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -5
- data/lib/rumale/svm/locally_linear_svc.rb +261 -0
- data/lib/rumale/svm/version.rb +2 -2
- data/lib/rumale/svm.rb +1 -0
- metadata +7 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8853a60f150df418c30832c66b04f6531a1f3fa16f96ed3cb2286ca88a553dbe
|
4
|
+
data.tar.gz: aded967a3b1f82ded39b07ef3e1f435959f26c8d8392fb212acee3fcd006366f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3f2ff48c445b9b5cfc804a17acb3f0c2563512eca01482c81c18fa8e662f9855e45abaede8ab5d7a8b087de84c131f2863f730d0313374b6af97480509fdd18f
|
7
|
+
data.tar.gz: 53747b28e162327cb11a9a5db9343d6baaad09e636909d8149c8ced58ba930a757b6251252968ddc7a7187c651abc56fb0ed046fde3ec6fcdc6a0d718404f420
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -5,13 +5,9 @@
|
|
5
5
|
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-svm/blob/main/LICENSE.txt)
|
6
6
|
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-svm/doc/)
|
7
7
|
|
8
|
-
Rumale::SVM provides support vector machine algorithms
|
8
|
+
Rumale::SVM provides support vector machine algorithms using
|
9
9
|
[LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) and [LIBLINEAR](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
|
10
10
|
with [Rumale](https://github.com/yoshoku/rumale) interface.
|
11
|
-
Many machine learning libraries use LIBSVM and LIBLINEAR as background libraries of support vector machine algorithms.
|
12
|
-
On the other hand, Rumale implements support vector machine algorithms based on the mini-batch stochastic gradient descent method
|
13
|
-
implemented in Ruby.
|
14
|
-
Rumale::SVM adds the functions of support vector machine similar to general machine learning libraries to Rumale.
|
15
11
|
|
16
12
|
## Installation
|
17
13
|
|
@@ -0,0 +1,261 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/classifier'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/utils'
|
7
|
+
require 'rumale/validation'
|
8
|
+
|
9
|
+
module Rumale
  module SVM
    # LocallyLinearSVC is a class that implements Locally Linear Support Vector Classifier with the squared hinge loss.
    # This classifier requires Numo::Linalg (or Numo::TinyLinalg) and Lbfgsb gems,
    # but they are not listed in the runtime dependencies of Rumale::SVM.
    # Therefore, you should install and load Numo::Linalg and Lbfgsb gems explicitly to use this classifier.
    #
    # @example
    #   require 'numo/linalg/autoloader'
    #   require 'lbfgsb'
    #   require 'rumale/svm'
    #
    #   estimator = Rumale::SVM::LocallyLinearSVC.new(reg_param: 1.0, n_anchors: 128)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Ladicky, L., and Torr, P H.S., "Locally Linear Support Vector Machines," Proc. ICML'11, pp. 985--992, 2011.
    class LocallyLinearSVC < Rumale::Base::Estimator
      include Rumale::Base::Classifier

      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the anchor vectors found by k-means clustering of the training samples.
      # @return [Numo::DFloat] (shape: [n_anchors, n_features])
      attr_reader :anchors

      # Return the weight vector.
      # For binary problems the leading n_classes axis is absent.
      # @return [Numo::DFloat] (shape: [n_classes, n_anchors, n_features] or [n_anchors, n_features])
      attr_reader :weight_vec

      # Return the bias term (a.k.a. intercept).
      # The value already includes the bias_scale factor, so it is added to the scores as is.
      # For binary problems the leading n_classes axis is absent.
      # @return [Numo::DFloat] (shape: [n_classes, n_anchors] or [n_anchors])
      attr_reader :bias_term

      # Create a new classifier with Locally Linear Support Vector Machine.
      #
      # @param reg_param [Float] The regularization parameter for weight vector.
      # @param reg_param_local [Float] The regularization parameter for local coordinate.
      # @param max_iter [Integer] The maximum number of iterations, used both for finding anchors with
      #   k-means algorithm and for optimizing the weights with L-BFGS-B.
      # @param tol [Float] The tolerance of termination criterion, used both for finding anchors with
      #   k-means algorithm and for optimizing the weights with L-BFGS-B.
      # @param n_anchors [Integer] The number of anchors.
      # @param n_neighbors [Integer] The number of neighbors (nearest anchors) used to compute local coordinates.
      #   If it is larger than n_anchors, all anchors are used.
      # @param fit_bias [Boolean] The flag indicating whether to fit bias term.
      # @param bias_scale [Float] The scale parameter for bias term.
      # @param verbose [Boolean] The flag indicating whether to output the progress messages of L-BFGS-B.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      # @raise [RuntimeError] if Numo::Linalg is not loaded.
      def initialize(reg_param: 1.0, reg_param_local: 1e-4, max_iter: 100, tol: 1e-4,
                     n_anchors: 128, n_neighbors: 10, fit_bias: true, bias_scale: 1.0,
                     verbose: false, random_seed: nil)
        raise 'LocallyLinearSVC requires Numo::Linalg but that is not loaded' unless enable_linalg?(warning: false)

        super()
        @params = {
          reg_param: reg_param,
          reg_param_local: reg_param_local,
          max_iter: max_iter,
          n_anchors: n_anchors,
          tol: tol,
          n_neighbors: n_neighbors,
          fit_bias: fit_bias,
          bias_scale: bias_scale,
          verbose: verbose,
          random_seed: random_seed || srand
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] if Lbfgsb is not loaded.
      # @return [LocallyLinearSVC] The learned classifier itself.
      def fit(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)
        raise 'LocallyLinearSVC#fit requires Lbfgsb but that is not loaded' unless defined?(Lbfgsb)

        @classes = Numo::Int32[*y.to_a.uniq.sort]

        find_anchors(x)
        n_samples, n_features = x.shape
        # Local coordinates are computed on the raw features, before the bias column is appended.
        @coeff = Numo::DFloat.zeros(n_samples, @params[:n_anchors])
        n_samples.times do |i|
          xi = x[i, true]
          @coeff[i, true] = local_coordinates(xi)
        end

        x = expand_feature(x) if fit_bias?

        if multiclass_problem?
          # One-vs-rest: each class is trained against all the others (+1 for the class, -1 otherwise).
          n_classes = @classes.size
          @weight_vec = Numo::DFloat.zeros(n_classes, @params[:n_anchors], n_features)
          @bias_term = Numo::DFloat.zeros(n_classes, @params[:n_anchors])
          n_classes.times do |n|
            bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
            w, b = partial_fit(x, bin_y)
            @weight_vec[n, true, true] = w
            @bias_term[n, true] = b
          end
        else
          negative_label = @classes[0]
          bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
          @weight_vec, @bias_term = partial_fit(x, bin_y)
        end

        self
      end

      # Calculate confidence scores for samples.
      # The score of a sample x is sum_k c_k * (w_k . x + b_k), where c is the local coordinate of x.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes] or [n_samples]) Confidence score per sample.
      #   The one-dimensional shape is returned for binary problems.
      def decision_function(x)
        x = Rumale::Validation.check_convert_sample_array(x)
        n_samples = x.shape[0]

        if multiclass_problem?
          n_classes = @classes.size
          df = Numo::DFloat.zeros(n_samples, n_classes)
          n_samples.times do |i|
            xi = x[i, true]
            coeff = local_coordinates(xi)
            n_classes.times do |j|
              df[i, j] = coeff.dot(@weight_vec[j, true, true]).dot(xi) + coeff.dot(@bias_term[j, true])
            end
          end
        else
          df = Numo::DFloat.zeros(n_samples)
          n_samples.times do |i|
            xi = x[i, true]
            coeff = local_coordinates(xi)
            df[i] = coeff.dot(@weight_vec).dot(xi) + coeff.dot(@bias_term)
          end
        end
        df
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)
        n_samples = x.shape[0]

        if multiclass_problem?
          df = decision_function(x)
          predicted = Array.new(n_samples) { |n| @classes[df[n, true].max_index] }
        else
          # A non-negative score selects the second (positive) class, otherwise the first one.
          df = decision_function(x).ge(0.0).to_a
          predicted = Array.new(n_samples) { |n| @classes[df[n]] }
        end
        Numo::Int32.asarray(predicted)
      end

      private

      # Optimize the locally linear weights of a binary problem with L-BFGS-B.
      # Returns the weight matrix (shape: [n_anchors, n_features]) and the intercept vector (shape: [n_anchors]).
      def partial_fit(base_x, bin_y) # rubocop:disable Metrics/AbcSize
        # Objective: 0.5 * reg_param * ||W||_F^2 + mean(max(0, 1 - y_i * z_i)^2),
        # where z_i = sum_k coeff_ik * (W_k . x_i). Returns [loss, gradient] because jcb: true is given below.
        fnc = proc do |w, x, y, coeff, reg_param|
          n_anchors = coeff.shape[1]
          n_samples, n_features = x.shape
          w = w.reshape(n_anchors, n_features)
          z = (coeff * x.dot(w.transpose)).sum(axis: 1)
          t = 1 - y * z
          indices = t.gt(0)
          grad = reg_param * w
          if indices.count.positive?
            # Only samples violating the margin contribute to the squared hinge loss gradient.
            sx = x[indices, true]
            sy = y[indices]
            sc = coeff[indices, true]
            sz = z[indices]
            grad += 2.fdiv(n_samples) * (sc.transpose * (sz - sy)).dot(sx)
          end
          loss = 0.5 * reg_param * w.dot(w.transpose).trace + (x.class.maximum(0, t)**2).sum.fdiv(n_samples)
          [loss, grad.reshape(n_anchors * n_features)]
        end

        n_features = base_x.shape[1]
        sub_rng = @rng.dup
        w_init = 2.0 * ::Rumale::Utils.rand_uniform(@params[:n_anchors] * n_features, sub_rng) - 1.0

        res = Lbfgsb.minimize(
          fnc: fnc, jcb: true, x_init: w_init, args: [base_x, bin_y, @coeff, @params[:reg_param]],
          maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON,
          verbose: @params[:verbose] ? 1 : -1
        )

        w = res[:x].reshape(@params[:n_anchors], n_features)

        if fit_bias?
          # The last column was learned against a constant feature of value bias_scale (see expand_feature),
          # so scale it to obtain the intercept that decision_function adds directly.
          [w[true, 0...-1].dup, w[true, -1] * @params[:bias_scale]]
        else
          [w, Numo::DFloat.zeros(@params[:n_anchors])]
        end
      end

      # Compute the local coordinate of a sample: the weights that reconstruct xi from its nearest anchors,
      # obtained from a regularized Gram system and normalized to sum to one. Non-neighbor anchors get zero.
      def local_coordinates(xi)
        neighbor_ids = find_neighbors(xi)
        # Use the actual neighbor count: it is smaller than n_neighbors when n_anchors < n_neighbors.
        n_neighbors = neighbor_ids.size
        diff = @anchors[neighbor_ids, true] - xi
        gram_mat = diff.dot(diff.transpose)
        gram_mat[gram_mat.diag_indices] += @params[:reg_param_local].fdiv(n_neighbors) * gram_mat.trace
        local_coeff = Numo::Linalg.solve(gram_mat, Numo::DFloat.ones(n_neighbors))
        local_coeff /= local_coeff.sum
        coeff = Numo::DFloat.zeros(@params[:n_anchors])
        coeff[neighbor_ids] = local_coeff
        coeff
      end

      # Return the indices of the (at most) n_neighbors anchors closest to xi in squared Euclidean distance.
      def find_neighbors(xi)
        diff = @anchors - xi
        dist = (diff**2).sum(axis: 1)
        dist.sort_index.to_a[0...@params[:n_neighbors]]
      end

      # Find anchors with k-means: initialize from randomly chosen samples, then alternate assignment and
      # mean update until the mean anchor shift is at most tol or max_iter iterations are reached.
      def find_anchors(x)
        n_samples = x.shape[0]
        sub_rng = @rng.dup
        rand_id = Array.new(@params[:n_anchors]) { |_v| sub_rng.rand(0...n_samples) }
        @anchors = x[rand_id, true].dup

        @params[:max_iter].times do |_t|
          center_ids = assign_anchors(x)
          old_anchors = @anchors.dup
          @params[:n_anchors].times do |n|
            assigned_bits = center_ids.eq(n)
            # Anchors with no assigned samples keep their previous position.
            @anchors[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          error = Numo::NMath.sqrt(((old_anchors - @anchors)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
      end

      # Return, for each sample, the index of its nearest anchor.
      def assign_anchors(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @anchors)
        # min_index(axis: 1) yields flat indices into the matrix; subtracting each row's offset gives column ids.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @anchors.shape[0])]
      end

      def fit_bias?
        @params[:fit_bias] == true
      end

      # Append a constant column of value bias_scale so the bias is learned as an extra weight.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
      end

      def multiclass_problem?
        @classes.size > 2
      end
    end
  end
end
|
data/lib/rumale/svm/version.rb
CHANGED
@@ -2,9 +2,9 @@
|
|
2
2
|
|
3
3
|
# Rumale is a machine learning library in Ruby.
|
4
4
|
module Rumale
|
5
|
-
# This module consists of Rumale interfaces for suppor vector machine algorithms
|
5
|
+
# This module consists of Rumale interfaces for support vector machine algorithms using LIBSVM and LIBLINEAR.
|
6
6
|
module SVM
|
7
7
|
# The version of Rumale::SVM you are using.
|
8
|
-
VERSION = '0.
|
8
|
+
VERSION = '0.9.0'
|
9
9
|
end
|
10
10
|
end
|
data/lib/rumale/svm.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-svm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-11-25 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-liblinear
|
@@ -52,8 +52,8 @@ dependencies:
|
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
54
|
version: '0.24'
|
55
|
-
description: 'Rumale::SVM provides support vector machine algorithms
|
56
|
-
LIBLINEAR with Rumale interface.
|
55
|
+
description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
|
56
|
+
and LIBLINEAR with Rumale interface.
|
57
57
|
|
58
58
|
'
|
59
59
|
email:
|
@@ -69,6 +69,7 @@ files:
|
|
69
69
|
- lib/rumale/svm/linear_one_class_svm.rb
|
70
70
|
- lib/rumale/svm/linear_svc.rb
|
71
71
|
- lib/rumale/svm/linear_svr.rb
|
72
|
+
- lib/rumale/svm/locally_linear_svc.rb
|
72
73
|
- lib/rumale/svm/logistic_regression.rb
|
73
74
|
- lib/rumale/svm/nu_svc.rb
|
74
75
|
- lib/rumale/svm/nu_svr.rb
|
@@ -110,9 +111,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
110
111
|
- !ruby/object:Gem::Version
|
111
112
|
version: '0'
|
112
113
|
requirements: []
|
113
|
-
rubygems_version: 3.
|
114
|
+
rubygems_version: 3.4.20
|
114
115
|
signing_key:
|
115
116
|
specification_version: 4
|
116
|
-
summary: Rumale::SVM provides support vector machine algorithms
|
117
|
+
summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
|
117
118
|
with Rumale interface.
|
118
119
|
test_files: []
|