rumale-svm 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -5
- data/lib/rumale/svm/locally_linear_svc.rb +261 -0
- data/lib/rumale/svm/version.rb +2 -2
- data/lib/rumale/svm.rb +1 -0
- metadata +7 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8853a60f150df418c30832c66b04f6531a1f3fa16f96ed3cb2286ca88a553dbe
|
4
|
+
data.tar.gz: aded967a3b1f82ded39b07ef3e1f435959f26c8d8392fb212acee3fcd006366f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3f2ff48c445b9b5cfc804a17acb3f0c2563512eca01482c81c18fa8e662f9855e45abaede8ab5d7a8b087de84c131f2863f730d0313374b6af97480509fdd18f
|
7
|
+
data.tar.gz: 53747b28e162327cb11a9a5db9343d6baaad09e636909d8149c8ced58ba930a757b6251252968ddc7a7187c651abc56fb0ed046fde3ec6fcdc6a0d718404f420
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -5,13 +5,9 @@
|
|
5
5
|
[](https://github.com/yoshoku/rumale-svm/blob/main/LICENSE.txt)
|
6
6
|
[](https://yoshoku.github.io/rumale-svm/doc/)
|
7
7
|
|
8
|
-
Rumale::SVM provides support vector machine algorithms
|
8
|
+
Rumale::SVM provides support vector machine algorithms using
|
9
9
|
[LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) and [LIBLINEAR](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
|
10
10
|
with [Rumale](https://github.com/yoshoku/rumale) interface.
|
11
|
-
Many machine learning libraries use LIBSVM and LIBLINEAR as background libraries of support vector machine algorithms.
|
12
|
-
On the other hand, Rumale implements support vector machine algorithms based on the mini-batch stochastic gradient descent method
|
13
|
-
implemented in Ruby.
|
14
|
-
Rumale::SVM adds the functions of support vector machine similar to general machine learning libraries to Rumale.
|
15
11
|
|
16
12
|
## Installation
|
17
13
|
|
@@ -0,0 +1,261 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/classifier'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/utils'
|
7
|
+
require 'rumale/validation'
|
8
|
+
|
9
|
+
module Rumale
  module SVM
    # LocallyLinearSVC is a class that implements Locally Linear Support Vector Classifier with the squared hinge loss.
    # This classifier requires Numo::Linalg (or Numo::TinyLinalg) and Lbfgsb gems,
    # but they are not listed in the runtime dependencies of Rumale::SVM.
    # Therefore, you should install and load Numo::Linalg and Lbfgsb gems explicitly to use this classifier.
    #
    # @example
    #   require 'numo/linalg/autoloader'
    #   require 'lbfgsb'
    #   require 'rumale/svm'
    #
    #   estimator = Rumale::SVM::LocallyLinearSVC.new(reg_param: 1.0, n_anchors: 128)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Ladicky, L., and Torr, P H.S., "Locally Linear Support Vector Machines," Proc. ICML'11, pp. 985--992, 2011.
    class LocallyLinearSVC < Rumale::Base::Estimator
      include Rumale::Base::Classifier

      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the anchor vectors.
      # @return [Numo::DFloat] (shape: [n_anchors, n_features])
      attr_reader :anchors

      # Return the weight vector.
      # @return [Numo::DFloat] (shape: [n_classes, n_anchors, n_features] for multiclass problems,
      #   [n_anchors, n_features] for binary problems)
      attr_reader :weight_vec

      # Return the bias term (a.k.a. intercept).
      # The value already includes bias_scale, so it is added to the decision function as-is.
      # @return [Numo::DFloat] (shape: [n_classes, n_anchors] for multiclass problems, [n_anchors] for binary problems)
      attr_reader :bias_term

      # Create a new classifier with Locally Linear Support Vector Machine.
      #
      # @param reg_param [Float] The regularization parameter for weight vector.
      # @param reg_param_local [Float] The regularization parameter for local coordinate.
      # @param max_iter [Integer] The maximum number of iterations, used both for finding anchors with k-means
      #   and for optimizing the weight vectors with L-BFGS-B.
      # @param tol [Float] The tolerance of termination criterion, used both for finding anchors with k-means
      #   and (scaled by machine epsilon) for optimizing the weight vectors with L-BFGS-B.
      # @param n_anchors [Integer] The number of anchors.
      # @param n_neighbors [Integer] The number of neighbors.
      # @param fit_bias [Boolean] The flag indicating whether to fit bias term.
      # @param bias_scale [Float] The scale parameter for bias term.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      # @param verbose [Boolean] The flag indicating whether to output the L-BFGS-B optimization progress.
      # @raise [RuntimeError] If Numo::Linalg (or Numo::TinyLinalg) is not loaded.
      def initialize(reg_param: 1.0, reg_param_local: 1e-4, max_iter: 100, tol: 1e-4,
                     n_anchors: 128, n_neighbors: 10, fit_bias: true, bias_scale: 1.0,
                     random_seed: nil, verbose: false)
        raise 'LocallyLinearSVC requires Numo::Linalg but that is not loaded' unless enable_linalg?(warning: false)

        super()
        @params = {
          reg_param: reg_param,
          reg_param_local: reg_param_local,
          max_iter: max_iter,
          n_anchors: n_anchors,
          tol: tol,
          n_neighbors: n_neighbors,
          fit_bias: fit_bias,
          bias_scale: bias_scale,
          verbose: verbose,
          random_seed: random_seed || srand
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [LocallyLinearSVC] The learned classifier itself.
      # @raise [RuntimeError] If the Lbfgsb gem is not loaded.
      def fit(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)
        raise 'LocallyLinearSVC#fit requires Lbfgsb but that is not loaded' unless defined?(Lbfgsb)

        @classes = Numo::Int32[*y.to_a.uniq.sort]

        find_anchors(x)
        n_samples, n_features = x.shape
        # Local coordinates depend only on the anchors, so they are computed once and shared by every class.
        @coeff = Numo::DFloat.zeros(n_samples, @params[:n_anchors])
        n_samples.times do |i|
          xi = x[i, true]
          @coeff[i, true] = local_coordinates(xi)
        end

        # The bias is learned as the weight of an extra constant feature whose value is bias_scale.
        x = expand_feature(x) if fit_bias?

        if multiclass_problem?
          # One-vs-rest: one binary problem (+1 for the class, -1 otherwise) per class.
          n_classes = @classes.size
          @weight_vec = Numo::DFloat.zeros(n_classes, @params[:n_anchors], n_features)
          @bias_term = Numo::DFloat.zeros(n_classes, @params[:n_anchors])
          n_classes.times do |n|
            bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
            w, b = partial_fit(x, bin_y)
            @weight_vec[n, true, true] = w
            @bias_term[n, true] = b
          end
        else
          # Binary: the smallest label is the negative class (-1), the other one is positive (+1).
          negative_label = @classes[0]
          bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
          @weight_vec, @bias_term = partial_fit(x, bin_y)
        end

        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes] for multiclass problems, [n_samples] for binary problems)
      #   Confidence score per sample.
      def decision_function(x)
        x = Rumale::Validation.check_convert_sample_array(x)
        n_samples = x.shape[0]

        if multiclass_problem?
          n_classes = @classes.size
          df = Numo::DFloat.zeros(n_samples, n_classes)
          n_samples.times do |i|
            xi = x[i, true]
            # The local coordinates blend the per-anchor linear models into one linear model for this sample.
            coeff = local_coordinates(xi)
            n_classes.times do |j|
              df[i, j] = coeff.dot(@weight_vec[j, true, true]).dot(xi) + coeff.dot(@bias_term[j, true])
            end
          end
        else
          df = Numo::DFloat.zeros(n_samples)
          n_samples.times do |i|
            xi = x[i, true]
            coeff = local_coordinates(xi)
            df[i] = coeff.dot(@weight_vec).dot(xi) + coeff.dot(@bias_term)
          end
        end
        df
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)
        n_samples = x.shape[0]

        if multiclass_problem?
          df = decision_function(x)
          predicted = Array.new(n_samples) { |n| @classes[df[n, true].max_index] }
        else
          # A non-negative score maps to index 1 (the positive class), a negative score to index 0.
          df = decision_function(x).ge(0.0).to_a
          predicted = Array.new(n_samples) { |n| @classes[df[n]] }
        end
        Numo::Int32.asarray(predicted)
      end

      private

      # Learn the per-anchor weight vectors for one binary (+1/-1) problem with L-BFGS-B.
      # Returns [weight, bias] where weight has shape [n_anchors, n_features_without_bias_column]
      # and bias has shape [n_anchors] (already multiplied by bias_scale when fit_bias is enabled).
      def partial_fit(base_x, bin_y) # rubocop:disable Metrics/AbcSize
        # Objective: 0.5 * reg_param * ||W||_F^2 + mean_i(max(0, 1 - y_i * z_i)^2),
        # where z_i = sum_a coeff[i, a] * (w_a . x_i). Returns [loss, flattened gradient].
        fnc = proc do |w, x, y, coeff, reg_param|
          n_anchors = coeff.shape[1]
          n_samples, n_features = x.shape
          w = w.reshape(n_anchors, n_features)
          z = (coeff * x.dot(w.transpose)).sum(axis: 1)
          t = 1 - y * z
          indices = t.gt(0)
          grad = reg_param * w
          if indices.count.positive?
            # Only samples inside the margin contribute; since y is +1/-1, d/dz max(0, 1 - y z)^2 = 2 (z - y).
            sx = x[indices, true]
            sy = y[indices]
            sc = coeff[indices, true]
            sz = z[indices]
            grad += 2.fdiv(n_samples) * (sc.transpose * (sz - sy)).dot(sx)
          end
          loss = 0.5 * reg_param * w.dot(w.transpose).trace + (x.class.maximum(0, t)**2).sum.fdiv(n_samples)
          [loss, grad.reshape(n_anchors * n_features)]
        end

        n_features = base_x.shape[1]
        sub_rng = @rng.dup
        w_init = 2.0 * ::Rumale::Utils.rand_uniform(@params[:n_anchors] * n_features, sub_rng) - 1.0

        res = Lbfgsb.minimize(
          fnc: fnc, jcb: true, x_init: w_init, args: [base_x, bin_y, @coeff, @params[:reg_param]],
          maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON,
          verbose: @params[:verbose] ? 1 : -1
        )

        w = res[:x].reshape(@params[:n_anchors], n_features)

        if fit_bias?
          # The last column is the weight of the constant feature whose value is bias_scale during training,
          # so the intercept it contributes to the decision function is that weight times bias_scale.
          [w[true, 0...-1].dup, w[true, -1] * @params[:bias_scale]]
        else
          [w, Numo::DFloat.zeros(@params[:n_anchors])]
        end
      end

      # Compute the local coordinates of one sample with respect to its nearest anchors:
      # weights summing to one, obtained from the regularized Gram matrix of (anchor - sample) differences.
      # Returns a vector of length n_anchors that is zero outside the selected neighbors.
      def local_coordinates(xi)
        neighbor_ids = find_neighbors(xi)
        # Fewer than n_neighbors ids are returned when n_anchors < n_neighbors.
        n_neighbors = neighbor_ids.size
        diff = @anchors[neighbor_ids, true] - xi
        gram_mat = diff.dot(diff.transpose)
        trace = gram_mat.trace
        # The ridge term is scaled by the trace; if the trace is zero (every neighbor coincides with xi),
        # fall back to unit scale so the system stays non-singular (this yields uniform weights).
        ridge_scale = trace.positive? ? trace : 1.0
        gram_mat[gram_mat.diag_indices] += @params[:reg_param_local].fdiv(n_neighbors) * ridge_scale
        local_coeff = Numo::Linalg.solve(gram_mat, Numo::DFloat.ones(n_neighbors))
        local_coeff /= local_coeff.sum
        coeff = Numo::DFloat.zeros(@params[:n_anchors])
        coeff[neighbor_ids] = local_coeff
        coeff
      end

      # Return the indices of the (at most) n_neighbors anchors closest to xi in squared Euclidean distance.
      def find_neighbors(xi)
        diff = @anchors - xi
        dist = (diff**2).sum(axis: 1)
        dist.sort_index.to_a[0...@params[:n_neighbors]]
      end

      # Find anchor vectors with k-means, initialized from randomly chosen training samples
      # (drawn with replacement), and store them in @anchors.
      def find_anchors(x)
        n_samples = x.shape[0]
        sub_rng = @rng.dup
        rand_id = Array.new(@params[:n_anchors]) { |_v| sub_rng.rand(0...n_samples) }
        @anchors = x[rand_id, true].dup

        @params[:max_iter].times do |_t|
          center_ids = assign_anchors(x)
          old_anchors = @anchors.dup
          @params[:n_anchors].times do |n|
            assigned_bits = center_ids.eq(n)
            # Anchors with no assigned samples keep their previous position.
            @anchors[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Stop when the mean movement of the anchors falls below tol.
          error = Numo::NMath.sqrt(((old_anchors - @anchors)**2).sum(axis: 1)).mean
          break if error <= @params[:tol]
        end
      end

      # Return, for each sample, the index of its closest anchor.
      def assign_anchors(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @anchors)
        # min_index(axis: 1) yields indices into the flattened matrix (i * n_anchors + j);
        # the remainder modulo n_anchors recovers the anchor index j.
        distance_matrix.min_index(axis: 1) % @anchors.shape[0]
      end

      def fit_bias?
        @params[:fit_bias] == true
      end

      # Append a constant column with value bias_scale so the bias can be learned as an ordinary weight.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
      end

      def multiclass_problem?
        @classes.size > 2
      end
    end
  end
end
|
data/lib/rumale/svm/version.rb
CHANGED
@@ -2,9 +2,9 @@
|
|
2
2
|
|
3
3
|
# Rumale is a machine learning library in Ruby.
|
4
4
|
module Rumale
|
5
|
-
# This module consists of Rumale interfaces for suppor vector machine algorithms
|
5
|
+
# This module consists of Rumale interfaces for support vector machine algorithms using LIBSVM and LIBLINEAR.
|
6
6
|
module SVM
|
7
7
|
# The version of Rumale::SVM you are using.
|
8
|
-
VERSION = '0.
|
8
|
+
VERSION = '0.9.0'
|
9
9
|
end
|
10
10
|
end
|
data/lib/rumale/svm.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-svm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.0
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-11-25 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-liblinear
|
@@ -52,8 +52,8 @@ dependencies:
|
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
54
|
version: '0.24'
|
55
|
-
description: 'Rumale::SVM provides support vector machine algorithms
|
56
|
-
LIBLINEAR with Rumale interface.
|
55
|
+
description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
|
56
|
+
and LIBLINEAR with Rumale interface.
|
57
57
|
|
58
58
|
'
|
59
59
|
email:
|
@@ -69,6 +69,7 @@ files:
|
|
69
69
|
- lib/rumale/svm/linear_one_class_svm.rb
|
70
70
|
- lib/rumale/svm/linear_svc.rb
|
71
71
|
- lib/rumale/svm/linear_svr.rb
|
72
|
+
- lib/rumale/svm/locally_linear_svc.rb
|
72
73
|
- lib/rumale/svm/logistic_regression.rb
|
73
74
|
- lib/rumale/svm/nu_svc.rb
|
74
75
|
- lib/rumale/svm/nu_svr.rb
|
@@ -110,9 +111,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
110
111
|
- !ruby/object:Gem::Version
|
111
112
|
version: '0'
|
112
113
|
requirements: []
|
113
|
-
rubygems_version: 3.
|
114
|
+
rubygems_version: 3.4.20
|
114
115
|
signing_key:
|
115
116
|
specification_version: 4
|
116
|
-
summary: Rumale::SVM provides support vector machine algorithms
|
117
|
+
summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
|
117
118
|
with Rumale interface.
|
118
119
|
test_files: []
|