rumale-svm 0.10.0 → 0.12.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/lib/rumale/svm/clustered_svc.rb +171 -0
- data/lib/rumale/svm/version.rb +1 -1
- data/lib/rumale/svm.rb +1 -0
- data/sig/rumale/svm/clustered_svc.rbs +29 -0
- metadata +7 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: eb43a42ed2be607317dde53e4f81788309aa27e28d1c169419e6e0d0cce073af
|
4
|
+
data.tar.gz: fe1ba15cd33b364a73b249e83c1b4a6937f50b5d7424db49957c2b4ca87c9267
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4aad8ec437d6522fbd3d90167a80a463d11b19ff60ee00d90d4a4f8810ee25a615592f2d5caeaf22fe9504aad9189f87c332626c27ef1f8de0f70c15b1908b51
|
7
|
+
data.tar.gz: d2b12aab3a83ba1ad71d89e49499b007f3054dbc0e49402b4a00ffc238450ade1701edb46bbd7d276dd231cd471831d2db62c210e4b76a096dc02a072c202796
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1
|
+
# [[0.12.0](https://github.com/yoshoku/rumale-svm/compare/v0.11.0...v0.12.0)]
|
2
|
+
- Fix the version specification of rumale-core gem.
|
3
|
+
|
4
|
+
# [[0.11.0](https://github.com/yoshoku/rumale-svm/compare/v0.10.0...v0.11.0)]
|
5
|
+
- Add Rumale::SVM::ClusteredSVC, a classifier based on the clustered support vector machine.
|
6
|
+
|
1
7
|
# 0.10.0
|
2
8
|
- Add Rumale::SVM::RandomRecursiveSVC, a classifier based on the random recursive support vector machine.
|
3
9
|
- Add type declaration files for RandomRecursiveSVC and LocallyLinearSVC.
|
data/LICENSE.txt
CHANGED
data/lib/rumale/svm/clustered_svc.rb
ADDED
@@ -0,0 +1,171 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/classifier'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/validation'
|
7
|
+
require 'rumale/svm/linear_svc'
|
8
|
+
|
9
|
+
module Rumale
  module SVM
    # ClusteredSVC is a class that implements Clustered Support Vector Classifier.
    # Samples are grouped with k-means clustering, and each sample is mapped to a feature vector made of
    # a shared (global) block and one block per cluster. A single LinearSVC is then trained on the mapped features.
    #
    # @example
    #   require 'rumale/svm'
    #
    #   estimator = Rumale::SVM::ClusteredSVC.new(n_clusters: 16, reg_param_global: 1.0, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Gu, Q., and Han, J., "Clustered Support Vector Machines," In Proc. AISTATS'13, pp. 307--315, 2013.
    class ClusteredSVC < Rumale::Base::Estimator
      include Rumale::Base::Classifier

      # Return the classifier.
      # @return [LinearSVC]
      attr_reader :model

      # Return the centroids.
      # The centroids are computed by the first call to {#fit} (or {#transform}) and are reused afterwards.
      # Assign nil to force them to be recomputed on the next call.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_accessor :cluster_centers

      # Create a new classifier with Clustered Support Vector Machine.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param reg_param_global [Float] The regularization parameter for global reference vector.
      #   The global block of the transformed features is scaled by 1 / sqrt(reg_param_global),
      #   so this value is expected to be positive.
      # @param max_iter_kmeans [Integer] The maximum number of iterations for k-means clustering.
      # @param tol_kmeans [Float] The tolerance of termination criterion for k-means clustering.
      # @param penalty [String] The type of norm used in the penalization ('l2' or 'l1').
      # @param loss [String] The type of loss function ('squared_hinge' or 'hinge').
      #   This parameter is ignored if penalty = 'l1'.
      # @param dual [Boolean] The flag indicating whether to solve dual optimization problem.
      #   When n_samples > n_features, dual = false is more preferable.
      #   This parameter is ignored if loss = 'hinge'.
      # @param reg_param [Float] The regularization parameter.
      # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
      # @param bias_scale [Float] The scale of the bias term.
      #   This parameter is ignored if fit_bias = false.
      # @param tol [Float] The tolerance of termination criterion.
      # @param verbose [Boolean] The flag indicating whether to output learning process message.
      # @param random_seed [Integer/Nil] The seed value used to initialize the random generator.
      def initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, # rubocop:disable Metrics/ParameterLists
                     penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0,
                     fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          reg_param_global: reg_param_global,
          max_iter_kmeans: max_iter_kmeans,
          tol_kmeans: tol_kmeans,
          penalty: penalty == 'l1' ? 'l1' : 'l2',
          loss: loss == 'hinge' ? 'hinge' : 'squared_hinge',
          dual: dual,
          reg_param: reg_param.to_f,
          fit_bias: fit_bias,
          bias_scale: bias_scale.to_f,
          tol: tol.to_f,
          verbose: verbose,
          random_seed: random_seed || Random.rand(4_294_967_295)
        }
        @rng = Random.new(@params[:random_seed])
        @cluster_centers = nil
      end

      # Fit the model with given training data.
      #
      # @note If {#cluster_centers} is already set (for example, by a previous call to fit), the existing
      #   centroids are reused and k-means clustering is not run again.
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [ClusteredSVC] The learned classifier itself.
      def fit(x, y)
        z = transform(x)
        @model = LinearSVC.new(**linear_svc_params).fit(z, y)
        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
      def decision_function(x)
        z = transform(x)
        @model.decision_function(z)
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        z = transform(x)
        @model.predict(z)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_cols * (1 + n_clusters)]) The transformed data,
      #   where n_cols is n_features, or n_features + 1 when fit_bias is true.
      def transform(x)
        # Centroids are computed lazily, so the first call (normally from #fit) runs k-means on x.
        clustering(x) if @cluster_centers.nil?

        cluster_ids = assign_cluster_id(x)

        # The bias column is appended here because the internal LinearSVC is created with fit_bias: false.
        x = expand_feature(x) if fit_bias?

        n_samples, n_features = x.shape
        z = Numo::DFloat.zeros(n_samples, n_features * (1 + @params[:n_clusters]))
        # Global block shared by every sample.
        z[true, 0...n_features] = 1.fdiv(Math.sqrt(@params[:reg_param_global])) * x
        # Local blocks: each sample is copied only into the slot of its assigned cluster.
        @params[:n_clusters].times do |n|
          assigned_bits = cluster_ids.eq(n)
          # Skip clusters with no assigned sample, which is common when transforming only a few samples.
          next unless assigned_bits.count.positive?

          assigned_ids = assigned_bits.where
          z[assigned_ids, n_features * (n + 1)...n_features * (n + 2)] = x[assigned_ids, true]
        end

        z
      end

      private

      # Hyperparameters for the internal LinearSVC. The clustering-specific entries are dropped, and
      # fit_bias is disabled because #transform already appends the bias column.
      def linear_svc_params
        @params.reject { |key, _| CLUSTERED_SVC_BINARY_PARAMS.include?(key) }.merge(fit_bias: false)
      end

      # Run k-means clustering on x and store the resulting centroids in @cluster_centers.
      # The initial centroids are n_clusters rows of x drawn at random (with replacement) using a copy
      # of @rng, so @rng itself is not advanced.
      def clustering(x)
        n_samples = x.shape[0]
        sub_rng = @rng.dup
        rand_id = Array.new(@params[:n_clusters]) { |_v| sub_rng.rand(0...n_samples) }
        @cluster_centers = x[rand_id, true].dup

        @params[:max_iter_kmeans].times do |_t|
          center_ids = assign_cluster_id(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = center_ids.eq(n)
            # Clusters that lost all their samples keep their previous centroid.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          # Stop when the mean Euclidean distance the centroids moved falls to tol_kmeans or below.
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol_kmeans]
        end
      end

      # Return the index of the nearest centroid for each sample.
      def assign_cluster_id(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index(axis: 1) yields indices into the flattened matrix; subtracting each row's offset
        # (row * n_clusters) turns them into column (cluster) indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Append a constant column of bias_scale to x.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
      end

      # Whether a bias column should be appended to the features (nil and false both mean no).
      def fit_bias?
        @params[:fit_bias] ? true : false
      end

      # Keys of @params used only by ClusteredSVC itself and not passed on to LinearSVC.
      CLUSTERED_SVC_BINARY_PARAMS = %i[n_clusters reg_param_global max_iter_kmeans tol_kmeans].freeze

      private_constant :CLUSTERED_SVC_BINARY_PARAMS
    end
  end
end
|
data/lib/rumale/svm/version.rb
CHANGED
data/lib/rumale/svm.rb
CHANGED
data/sig/rumale/svm/clustered_svc.rbs
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
# TypeProf 0.21.8
|
2
|
+
|
3
|
+
# Classes
|
4
|
+
module Rumale
  module SVM
    # Type signatures for the Clustered Support Vector Classifier.
    class ClusteredSVC
      # Hyperparameters stored by #initialize.
      @params: {n_clusters: Integer, reg_param_global: Float, max_iter_kmeans: Integer, tol_kmeans: Float, penalty: String, loss: String, dual: bool, reg_param: Float, fit_bias: bool, bias_scale: Float, tol: Float, verbose: bool, random_seed: Integer}
      # Random generator seeded with random_seed; used to pick the initial k-means centroids.
      @rng: Random

      # Linear classifier trained on the transformed features.
      attr_reader model: Rumale::SVM::LinearSVC
      # k-means centroids. NOTE(review): this is nil until the first fit/transform; consider `Numo::DFloat?` — TODO confirm.
      attr_accessor cluster_centers: Numo::DFloat
      def initialize: (?n_clusters: Integer, ?reg_param_global: Float, ?max_iter_kmeans: Integer, ?tol_kmeans: Float, ?penalty: String, ?loss: String, ?dual: bool, ?reg_param: Float, ?fit_bias: bool, ?bias_scale: Float, ?tol: Float, ?verbose: bool, ?random_seed: (nil | Integer)) -> void
      def fit: (Numo::DFloat x, Numo::Int32 y) -> ClusteredSVC
      def decision_function: (Numo::DFloat x) -> Numo::DFloat
      def predict: (Numo::DFloat x) -> Numo::Int32
      # Maps samples to the global block plus one block per cluster.
      def transform: (Numo::DFloat x) -> Numo::DFloat

      private

      # Hyperparameters passed to the internal LinearSVC.
      def linear_svc_params: -> (Hash[:bias_scale | :dual | :fit_bias | :loss | :max_iter_kmeans | :n_clusters | :penalty | :random_seed | :reg_param | :reg_param_global | :tol | :tol_kmeans | :verbose, Float | Integer | String | bool])
      # Runs k-means and stores the centroids.
      def clustering: (Numo::DFloat x) -> void
      # Index of the nearest centroid per sample.
      def assign_cluster_id: (Numo::DFloat x) -> Numo::Int32
      # Appends the bias column.
      def expand_feature: (Numo::DFloat x) -> Numo::DFloat
      def fit_bias?: -> bool

      # Keys of @params that are not forwarded to LinearSVC.
      CLUSTERED_SVC_BINARY_PARAMS: [:n_clusters, :reg_param_global, :max_iter_kmeans, :tol_kmeans]
    end
  end
end
|
metadata
CHANGED
@@ -1,14 +1,13 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-svm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.10.0
|
4
|
+
version: 0.12.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
|
-
autorequire:
|
9
8
|
bindir: exe
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-02 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
13
|
name: numo-liblinear
|
@@ -42,14 +41,14 @@ dependencies:
|
|
42
41
|
name: rumale-core
|
43
42
|
requirement: !ruby/object:Gem::Requirement
|
44
43
|
requirements:
|
45
|
-
- - "
|
44
|
+
- - ">="
|
46
45
|
- !ruby/object:Gem::Version
|
47
46
|
version: '0.24'
|
48
47
|
type: :runtime
|
49
48
|
prerelease: false
|
50
49
|
version_requirements: !ruby/object:Gem::Requirement
|
51
50
|
requirements:
|
52
|
-
- - "
|
51
|
+
- - ">="
|
53
52
|
- !ruby/object:Gem::Version
|
54
53
|
version: '0.24'
|
55
54
|
description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
|
@@ -66,6 +65,7 @@ files:
|
|
66
65
|
- LICENSE.txt
|
67
66
|
- README.md
|
68
67
|
- lib/rumale/svm.rb
|
68
|
+
- lib/rumale/svm/clustered_svc.rb
|
69
69
|
- lib/rumale/svm/linear_one_class_svm.rb
|
70
70
|
- lib/rumale/svm/linear_svc.rb
|
71
71
|
- lib/rumale/svm/linear_svr.rb
|
@@ -79,6 +79,7 @@ files:
|
|
79
79
|
- lib/rumale/svm/svr.rb
|
80
80
|
- lib/rumale/svm/version.rb
|
81
81
|
- sig/rumale/svm.rbs
|
82
|
+
- sig/rumale/svm/clustered_svc.rbs
|
82
83
|
- sig/rumale/svm/linear_one_class_svm.rbs
|
83
84
|
- sig/rumale/svm/linear_svc.rbs
|
84
85
|
- sig/rumale/svm/linear_svr.rbs
|
@@ -99,7 +100,6 @@ metadata:
|
|
99
100
|
changelog_uri: https://github.com/yoshoku/rumale-svm/blob/main/CHANGELOG.md
|
100
101
|
documentation_uri: https://yoshoku.github.io/rumale-svm/doc/
|
101
102
|
rubygems_mfa_required: 'true'
|
102
|
-
post_install_message:
|
103
103
|
rdoc_options: []
|
104
104
|
require_paths:
|
105
105
|
- lib
|
@@ -114,8 +114,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
114
114
|
- !ruby/object:Gem::Version
|
115
115
|
version: '0'
|
116
116
|
requirements: []
|
117
|
-
rubygems_version: 3.
|
118
|
-
signing_key:
|
117
|
+
rubygems_version: 3.6.2
|
119
118
|
specification_version: 4
|
120
119
|
summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
|
121
120
|
with Rumale interface.
|