rumale-svm 0.8.0 → 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5611ab7bee24d673f0803b5679532326389a9a485dd5beafeb9d448ab5538363
4
- data.tar.gz: 033024146647d6ba5e4acf08eb876dc35e0fe25f141555891dec77e84cd6572b
3
+ metadata.gz: 8853a60f150df418c30832c66b04f6531a1f3fa16f96ed3cb2286ca88a553dbe
4
+ data.tar.gz: aded967a3b1f82ded39b07ef3e1f435959f26c8d8392fb212acee3fcd006366f
5
5
  SHA512:
6
- metadata.gz: b70772fd9480c0de3f28d477bdf35ccae3260f9a9d1156423cac8654caebbc1891a9182e1ee3a4b3b7add8ce885ac06c2e57e8d3ca46a71e9d7de99d46d1c33d
7
- data.tar.gz: a38ccc2f952bb1e3fe5c4906e9225bb2d6c5c938b93aa6d08a8b66a2b20f4e94b0d96fe08c86f489050051eb149cdf68b74826b6a5f7439fbab5ee47de2b67e6
6
+ metadata.gz: 3f2ff48c445b9b5cfc804a17acb3f0c2563512eca01482c81c18fa8e662f9855e45abaede8ab5d7a8b087de84c131f2863f730d0313374b6af97480509fdd18f
7
+ data.tar.gz: 53747b28e162327cb11a9a5db9343d6baaad09e636909d8149c8ced58ba930a757b6251252968ddc7a7187c651abc56fb0ed046fde3ec6fcdc6a0d718404f420
data/CHANGELOG.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.9.0
2
+ - Add Rumale::SVM::LocallyLinearSVC, a classifier based on the locally linear support vector machine.
3
+
1
4
  # 0.8.0
2
5
  - Refactor to support the new Rumale API.
3
6
 
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2019-2022 Atsushi Tatsuma
1
+ Copyright (c) 2019-2023 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -5,13 +5,9 @@
5
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-svm/blob/main/LICENSE.txt)
6
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-svm/doc/)
7
7
 
8
- Rumale::SVM provides support vector machine algorithms in
8
+ Rumale::SVM provides support vector machine algorithms using
9
9
  [LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) and [LIBLINEAR](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
10
10
  with [Rumale](https://github.com/yoshoku/rumale) interface.
11
- Many machine learning libraries use LIBSVM and LIBLINEAR as background libraries of support vector machine algorithms.
12
- On the other hand, Rumale implements support vector machine algorithms based on the mini-batch stochastic gradient descent method
13
- implemented in Ruby.
14
- Rumale::SVM adds the functions of support vector machine similar to general machine learning libraries to Rumale.
15
11
 
16
12
  ## Installation
17
13
 
@@ -0,0 +1,261 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/classifier'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/utils'
7
+ require 'rumale/validation'
8
+
9
+ module Rumale
10
+ module SVM
11
+ # LocallyLinearSVC is a class that implements Locally Linear Support Vector Classifier with the squared hinge loss.
12
+ # This classifier requires Numo::Linalg (or Numo::TinyLinalg) and Lbfgsb gems,
13
+ # but they are not listed in the runtime dependencies of Rumale::SVM.
14
+ # Therefore, you should install and load Numo::Linalg and Lbfgsb gems explicitly to use this classifier.
15
+ #
16
+ # @example
17
+ # require 'numo/linalg/autoloader'
18
+ # require 'lbfgsb'
19
+ # require 'rumale/svm'
20
+ #
21
+ # estimator = Rumale::SVM::LocallyLinearSVC.new(reg_param: 1.0, n_anchors: 128)
22
+ # estimator.fit(training_samples, training_labels)
23
+ # results = estimator.predict(testing_samples)
24
+ #
25
+ # *Reference*
26
+ # - Ladicky, L., and Torr, P. H. S., "Locally Linear Support Vector Machines," Proc. ICML'11, pp. 985--992, 2011.
27
+ class LocallyLinearSVC < Rumale::Base::Estimator
28
+ include Rumale::Base::Classifier
29
+
30
+ # Return the class labels.
31
+ # @return [Numo::Int32] (size: n_classes)
32
+ attr_reader :classes
33
+
34
+ # Return the anchor vectors.
35
+ # @return [Numo::DFloat] (shape: [n_anchors, n_features])
36
+ attr_reader :anchors
37
+
38
+ # Return the weight vector.
39
+ # @return [Numo::DFloat] (shape: [n_classes, n_anchors, n_features])
40
+ attr_reader :weight_vec
41
+
42
+ # Return the bias term (a.k.a. intercept).
43
+ # @return [Numo::DFloat] (shape: [n_classes, n_anchors])
44
+ attr_reader :bias_term
45
+
46
+ # Create a new classifier with Locally Linear Support Vector Machine.
47
+ #
48
+ # @param reg_param [Float] The regularization parameter for weight vector.
49
+ # @param reg_param_local [Float] The regularization parameter for local coordinate.
50
+ # @param max_iter [Integer] The maximum number of iterations, applied to both the k-means anchor search and the L-BFGS-B optimization.
51
+ # @param tol [Float] The tolerance of termination criterion, used both for finding anchors with k-means algorithm and for the L-BFGS-B optimization.
52
+ # @param n_anchors [Integer] The number of anchors.
53
+ # @param n_neighbors [Integer] The number of neighbors.
54
+ # @param fit_bias [Boolean] The flag indicating whether to fit bias term.
55
+ # @param bias_scale [Float] The scale parameter for bias term.
56
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
57
+ def initialize(reg_param: 1.0, reg_param_local: 1e-4, max_iter: 100, tol: 1e-4,
58
+ n_anchors: 128, n_neighbors: 10, fit_bias: true, bias_scale: 1.0, random_seed: nil)
59
+ raise 'LocallyLinearSVC requires Numo::Linalg but that is not loaded' unless enable_linalg?(warning: false)
60
+
61
+ super()
62
+ @params = {
63
+ reg_param: reg_param,
64
+ reg_param_local: reg_param_local,
65
+ max_iter: max_iter,
66
+ n_anchors: n_anchors,
67
+ tol: tol,
68
+ n_neighbors: n_neighbors,
69
+ fit_bias: fit_bias,
70
+ bias_scale: bias_scale,
71
+ random_seed: random_seed || srand
72
+ }
73
+ @rng = Random.new(@params[:random_seed])
74
+ end
75
+
76
+ # Fit the model with given training data.
77
+ #
78
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
79
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
80
+ # @return [LocallyLinearSVC] The learned classifier itself.
81
+ def fit(x, y)
82
+ x = Rumale::Validation.check_convert_sample_array(x)
83
+ y = Rumale::Validation.check_convert_label_array(y)
84
+ Rumale::Validation.check_sample_size(x, y)
85
+ raise 'LocallyLinearSVC#fit requires Lbfgsb but that is not loaded' unless defined?(Lbfgsb)
86
+
87
+ @classes = Numo::Int32[*y.to_a.uniq.sort]
88
+
89
+ find_anchors(x)
90
+ n_samples, n_features = x.shape
91
+ @coeff = Numo::DFloat.zeros(n_samples, @params[:n_anchors])
92
+ n_samples.times do |i|
93
+ xi = x[i, true]
94
+ @coeff[i, true] = local_coordinates(xi)
95
+ end
96
+
97
+ x = expand_feature(x) if fit_bias?
98
+
99
+ if multiclass_problem?
100
+ n_classes = @classes.size
101
+ @weight_vec = Numo::DFloat.zeros(n_classes, @params[:n_anchors], n_features)
102
+ @bias_term = Numo::DFloat.zeros(n_classes, @params[:n_anchors])
103
+ n_classes.times do |n|
104
+ bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
105
+ w, b = partial_fit(x, bin_y)
106
+ @weight_vec[n, true, true] = w
107
+ @bias_term[n, true] = b
108
+ end
109
+ else
110
+ negative_label = @classes[0]
111
+ bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
112
+ @weight_vec, @bias_term = partial_fit(x, bin_y)
113
+ end
114
+
115
+ self
116
+ end
117
+
118
+ # Calculate confidence scores for samples.
119
+ #
120
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
121
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
122
+ def decision_function(x)
123
+ x = Rumale::Validation.check_convert_sample_array(x)
124
+ n_samples = x.shape[0]
125
+
126
+ if multiclass_problem?
127
+ n_classes = @classes.size
128
+ df = Numo::DFloat.zeros(n_samples, n_classes)
129
+ n_samples.times do |i|
130
+ xi = x[i, true]
131
+ coeff = local_coordinates(xi)
132
+ n_classes.times do |j|
133
+ df[i, j] = coeff.dot(@weight_vec[j, true, true]).dot(xi) + coeff.dot(@bias_term[j, true])
134
+ end
135
+ end
136
+ else
137
+ df = Numo::DFloat.zeros(n_samples)
138
+ n_samples.times do |i|
139
+ xi = x[i, true]
140
+ coeff = local_coordinates(xi)
141
+ df[i] = coeff.dot(@weight_vec).dot(xi) + coeff.dot(@bias_term)
142
+ end
143
+ end
144
+ df
145
+ end
146
+
147
+ # Predict class labels for samples.
148
+ #
149
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
150
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
151
+ def predict(x)
152
+ x = Rumale::Validation.check_convert_sample_array(x)
153
+ n_samples = x.shape[0]
154
+
155
+ if multiclass_problem?
156
+ df = decision_function(x)
157
+ predicted = Array.new(n_samples) { |n| @classes[df[n, true].max_index] }
158
+ else
159
+ df = decision_function(x).ge(0.0).to_a
160
+ predicted = Array.new(n_samples) { |n| @classes[df[n]] }
161
+ end
162
+ Numo::Int32.asarray(predicted)
163
+ end
164
+
165
+ private
166
+
167
+ def partial_fit(base_x, bin_y) # rubocop:disable Metrics/AbcSize
168
+ fnc = proc do |w, x, y, coeff, reg_param|
169
+ n_anchors = coeff.shape[1]
170
+ n_samples, n_features = x.shape
171
+ w = w.reshape(n_anchors, n_features)
172
+ z = (coeff * x.dot(w.transpose)).sum(axis: 1)
173
+ t = 1 - y * z
174
+ indices = t.gt(0)
175
+ grad = reg_param * w
176
+ if indices.count.positive?
177
+ sx = x[indices, true]
178
+ sy = y[indices]
179
+ sc = coeff[indices, true]
180
+ sz = z[indices]
181
+ grad += 2.fdiv(n_samples) * (sc.transpose * (sz - sy)).dot(sx)
182
+ end
183
+ loss = 0.5 * reg_param * w.dot(w.transpose).trace + (x.class.maximum(0, t)**2).sum.fdiv(n_samples)
184
+ [loss, grad.reshape(n_anchors * n_features)]
185
+ end
186
+
187
+ n_features = base_x.shape[1]
188
+ sub_rng = @rng.dup
189
+ w_init = 2.0 * ::Rumale::Utils.rand_uniform(@params[:n_anchors] * n_features, sub_rng) - 1.0
190
+
191
+ res = Lbfgsb.minimize(
192
+ fnc: fnc, jcb: true, x_init: w_init, args: [base_x, bin_y, @coeff, @params[:reg_param]],
193
+ maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON,
194
+ verbose: @params[:verbose] ? 1 : -1
195
+ )
196
+
197
+ w = res[:x].reshape(@params[:n_anchors], n_features)
198
+
199
+ if fit_bias?
200
+ [w[true, 0...-1].dup, w[true, -1].dup]
201
+ else
202
+ [w, Numo::DFloat.zeros(@params[:n_anchors])]
203
+ end
204
+ end
205
+
206
+ def local_coordinates(xi)
207
+ neighbor_ids = find_neighbors(xi)
208
+ diff = @anchors[neighbor_ids, true] - xi
209
+ gram_mat = diff.dot(diff.transpose)
210
+ gram_mat[gram_mat.diag_indices] += @params[:reg_param_local].fdiv(@params[:n_neighbors]) * gram_mat.trace
211
+ local_coeff = Numo::Linalg.solve(gram_mat, Numo::DFloat.ones(@params[:n_neighbors]))
212
+ local_coeff /= local_coeff.sum # + 1e-8
213
+ coeff = Numo::DFloat.zeros(@params[:n_anchors])
214
+ coeff[neighbor_ids] = local_coeff
215
+ coeff
216
+ end
217
+
218
+ def find_neighbors(xi)
219
+ diff = @anchors - xi
220
+ dist = (diff**2).sum(axis: 1)
221
+ dist.sort_index.to_a[0...@params[:n_neighbors]]
222
+ end
223
+
224
+ def find_anchors(x)
225
+ n_samples = x.shape[0]
226
+ sub_rng = @rng.dup
227
+ rand_id = Array.new(@params[:n_anchors]) { |_v| sub_rng.rand(0...n_samples) }
228
+ @anchors = x[rand_id, true].dup
229
+
230
+ @params[:max_iter].times do |_t|
231
+ center_ids = assign_anchors(x)
232
+ old_anchors = @anchors.dup
233
+ @params[:n_anchors].times do |n|
234
+ assigned_bits = center_ids.eq(n)
235
+ @anchors[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
236
+ end
237
+ error = Numo::NMath.sqrt(((old_anchors - @anchors)**2).sum(axis: 1)).mean
238
+ break if error <= @params[:tol]
239
+ end
240
+ end
241
+
242
+ def assign_anchors(x)
243
+ distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @anchors)
244
+ distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @anchors.shape[0])]
245
+ end
246
+
247
+ def fit_bias?
248
+ @params[:fit_bias] == true
249
+ end
250
+
251
+ def expand_feature(x)
252
+ n_samples = x.shape[0]
253
+ Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
254
+ end
255
+
256
+ def multiclass_problem?
257
+ @classes.size > 2
258
+ end
259
+ end
260
+ end
261
+ end
@@ -2,9 +2,9 @@
2
2
 
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
- # This module consists of Rumale interfaces for suppor vector machine algorithms with LIBSVM and LIBLINEAR.
5
+ # This module consists of Rumale interfaces for support vector machine algorithms using LIBSVM and LIBLINEAR.
6
6
  module SVM
7
7
  # The version of Rumale::SVM you are using.
8
- VERSION = '0.8.0'
8
+ VERSION = '0.9.0'
9
9
  end
10
10
  end
data/lib/rumale/svm.rb CHANGED
@@ -10,3 +10,4 @@ require 'rumale/svm/linear_svc'
10
10
  require 'rumale/svm/linear_svr'
11
11
  require 'rumale/svm/logistic_regression'
12
12
  require 'rumale/svm/linear_one_class_svm'
13
+ require 'rumale/svm/locally_linear_svc'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-svm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.0
4
+ version: 0.9.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-01-01 00:00:00.000000000 Z
11
+ date: 2023-11-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-liblinear
@@ -52,8 +52,8 @@ dependencies:
52
52
  - - "~>"
53
53
  - !ruby/object:Gem::Version
54
54
  version: '0.24'
55
- description: 'Rumale::SVM provides support vector machine algorithms of LIBSVM and
56
- LIBLINEAR with Rumale interface.
55
+ description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
56
+ and LIBLINEAR with Rumale interface.
57
57
 
58
58
  '
59
59
  email:
@@ -69,6 +69,7 @@ files:
69
69
  - lib/rumale/svm/linear_one_class_svm.rb
70
70
  - lib/rumale/svm/linear_svc.rb
71
71
  - lib/rumale/svm/linear_svr.rb
72
+ - lib/rumale/svm/locally_linear_svc.rb
72
73
  - lib/rumale/svm/logistic_regression.rb
73
74
  - lib/rumale/svm/nu_svc.rb
74
75
  - lib/rumale/svm/nu_svr.rb
@@ -110,9 +111,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
110
111
  - !ruby/object:Gem::Version
111
112
  version: '0'
112
113
  requirements: []
113
- rubygems_version: 3.3.26
114
+ rubygems_version: 3.4.20
114
115
  signing_key:
115
116
  specification_version: 4
116
- summary: Rumale::SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
117
+ summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
117
118
  with Rumale interface.
118
119
  test_files: []