rumale-svm 0.10.0 → 0.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/lib/rumale/svm/clustered_svc.rb +171 -0
- data/lib/rumale/svm/version.rb +1 -1
- data/lib/rumale/svm.rb +1 -0
- data/sig/rumale/svm/clustered_svc.rbs +29 -0
- metadata +7 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: eb43a42ed2be607317dde53e4f81788309aa27e28d1c169419e6e0d0cce073af
|
4
|
+
data.tar.gz: fe1ba15cd33b364a73b249e83c1b4a6937f50b5d7424db49957c2b4ca87c9267
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4aad8ec437d6522fbd3d90167a80a463d11b19ff60ee00d90d4a4f8810ee25a615592f2d5caeaf22fe9504aad9189f87c332626c27ef1f8de0f70c15b1908b51
|
7
|
+
data.tar.gz: d2b12aab3a83ba1ad71d89e49499b007f3054dbc0e49402b4a00ffc238450ade1701edb46bbd7d276dd231cd471831d2db62c210e4b76a096dc02a072c202796
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1
|
+
# [[0.12.0](https://github.com/yoshoku/rumale-svm/compare/v0.11.0...v0.12.0)]
|
2
|
+
- Fix the version specification of rumale-core gem.
|
3
|
+
|
4
|
+
# [[0.11.0](https://github.com/yoshoku/rumale-svm/compare/v0.10.0...v0.11.0)]
|
5
|
+
- Add Rumale::SVM::ClusteredSVC, a classifier based on the clustered support vector machine.
|
6
|
+
|
1
7
|
# 0.10.0
|
2
8
|
- Add Rumale::SVM::RandomRecursiveSVC, a classifier based on the random recursive support vector machine.
|
3
9
|
- Add type declaration files for RandomRecursiveSVC and LocallyLinearSVC.
|
data/LICENSE.txt
CHANGED
@@ -0,0 +1,171 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/classifier'
|
5
|
+
require 'rumale/pairwise_metric'
|
6
|
+
require 'rumale/validation'
|
7
|
+
require 'rumale/svm/linear_svc'
|
8
|
+
|
9
|
+
module Rumale
|
10
|
+
module SVM
|
11
|
+
# ClusteredSVC is a class that implements Clustered Support Vector Classifier.
# It partitions the samples with k-means, builds an expanded feature vector made of
# a shared global block and one cluster-specific block per cluster, and trains a
# LinearSVC on the expanded features.
#
# @example
#   require 'rumale/svm'
#
#   estimator = Rumale::SVM::ClusteredSVC.new(n_clusters: 16, reg_param_global: 1.0, random_seed: 1)
#   estimator.fit(training_samples, training_labels)
#   results = estimator.predict(testing_samples)
#
# *Reference*
# - Gu, Q., and Han, J., "Clustered Support Vector Machines," In Proc. AISTATS'13, pp. 307--315, 2013.
class ClusteredSVC < Rumale::Base::Estimator
  include Rumale::Base::Classifier

  # Return the classifier.
  # @return [LinearSVC]
  attr_reader :model

  # Return the centroids (nil until the estimator has been fitted or transform has been called).
  # @return [Numo::DFloat] (shape: [n_clusters, n_features])
  attr_accessor :cluster_centers

  # Create a new classifier with Clustered Support Vector Machine.
  #
  # @param n_clusters [Integer] The number of clusters.
  # @param reg_param_global [Float] The regularization parameter for global reference vector.
  # @param max_iter_kmeans [Integer] The maximum number of iterations for k-means clustering.
  # @param tol_kmeans [Float] The tolerance of termination criterion for k-means clustering.
  # @param penalty [String] The type of norm used in the penalization ('l2' or 'l1').
  # @param loss [String] The type of loss function ('squared_hinge' or 'hinge').
  #   This parameter is ignored if penalty = 'l1'.
  # @param dual [Boolean] The flag indicating whether to solve dual optimization problem.
  #   When n_samples > n_features, dual = false is more preferable.
  #   This parameter is ignored if loss = 'hinge'.
  # @param reg_param [Float] The regularization parameter.
  # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
  # @param bias_scale [Float] The scale of the bias term.
  #   This parameter is ignored if fit_bias = false.
  # @param tol [Float] The tolerance of termination criterion.
  # @param verbose [Boolean] The flag indicating whether to output learning process message
  # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
  def initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, # rubocop:disable Metrics/ParameterLists
                 penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0,
                 fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil)
    super()
    @params = {
      n_clusters: n_clusters,
      # Coerce to Float like the other real-valued hyperparameters below.
      reg_param_global: reg_param_global.to_f,
      max_iter_kmeans: max_iter_kmeans,
      tol_kmeans: tol_kmeans.to_f,
      # Unknown values fall back to the defaults instead of raising.
      penalty: penalty == 'l1' ? 'l1' : 'l2',
      loss: loss == 'hinge' ? 'hinge' : 'squared_hinge',
      dual: dual,
      reg_param: reg_param.to_f,
      fit_bias: fit_bias,
      bias_scale: bias_scale.to_f,
      tol: tol.to_f,
      verbose: verbose,
      random_seed: random_seed || Random.rand(4_294_967_295)
    }
    @rng = Random.new(@params[:random_seed])
    @cluster_centers = nil
  end

  # Fit the model with given training data.
  # The k-means centroids are always re-learned from +x+, so calling fit again
  # on different data does not reuse centroids from a previous fit.
  #
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
  # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
  # @return [ClusteredSVC] The learned classifier itself.
  def fit(x, y)
    x = ::Rumale::Validation.check_convert_sample_array(x)
    y = ::Rumale::Validation.check_convert_label_array(y)
    ::Rumale::Validation.check_sample_size(x, y)

    clustering(x)
    z = transform(x)
    # The bias column is already appended by transform, so LinearSVC must not add another one.
    @model = LinearSVC.new(**linear_svc_params).fit(z, y)
    self
  end

  # Calculate confidence scores for samples.
  #
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
  # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
  def decision_function(x)
    z = transform(x)
    @model.decision_function(z)
  end

  # Predict class labels for samples.
  #
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
  # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
  def predict(x)
    z = transform(x)
    @model.predict(z)
  end

  # Transform the given data with the learned model.
  # If no centroids are available yet, k-means is run on +x+ first.
  #
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
  # @return [Numo::DFloat] (shape: [n_samples, n_features + n_features * n_clusters]) The transformed data,
  #   where n_features counts the appended bias column when fit_bias is true.
  def transform(x)
    x = ::Rumale::Validation.check_convert_sample_array(x)
    clustering(x) if @cluster_centers.nil?

    # Cluster membership is decided on the original features, before the bias column is added.
    cluster_ids = assign_cluster_id(x)

    x = expand_feature(x) if fit_bias?

    n_samples, n_features = x.shape
    z = Numo::DFloat.zeros(n_samples, n_features * (1 + @params[:n_clusters]))
    # Global block, scaled by 1 / sqrt(reg_param_global).
    z[true, 0...n_features] = 1.fdiv(Math.sqrt(@params[:reg_param_global])) * x
    # Cluster-specific blocks: each sample is copied only into the block of its own cluster.
    @params[:n_clusters].times do |n|
      member_ids = cluster_ids.eq(n).where
      z[member_ids, n_features * (n + 1)...n_features * (n + 2)] = x[member_ids, true]
    end

    z
  end

  private

  # Hyperparameters forwarded to LinearSVC: k-means-specific keys removed, bias handled by transform.
  def linear_svc_params
    @params.reject { |key, _| CLUSTERED_SVC_BINARY_PARAMS.include?(key) }.merge(fit_bias: false)
  end

  # Run k-means (Lloyd iterations) on x and store the centroids in @cluster_centers.
  # NOTE: initial centroids are drawn with replacement, so duplicate starting points are possible.
  def clustering(x)
    n_samples = x.shape[0]
    # Duplicate the generator so every clustering run starts from the same random state.
    sub_rng = @rng.dup
    rand_id = Array.new(@params[:n_clusters]) { |_v| sub_rng.rand(0...n_samples) }
    @cluster_centers = x[rand_id, true].dup

    @params[:max_iter_kmeans].times do |_t|
      center_ids = assign_cluster_id(x)
      old_centers = @cluster_centers.dup
      @params[:n_clusters].times do |n|
        assigned_bits = center_ids.eq(n)
        # Empty clusters keep their previous centroid.
        @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
      end
      error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
      break if error <= @params[:tol_kmeans]
    end
  end

  # Index of the nearest centroid for each row of x.
  def assign_cluster_id(x)
    distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
    # min_index returns flat indices; subtracting each row's offset yields per-row column indices.
    distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
  end

  # Append a constant bias column (value bias_scale) to x.
  def expand_feature(x)
    n_samples = x.shape[0]
    Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
  end

  # True unless fit_bias is nil or false.
  def fit_bias?
    @params[:fit_bias] ? true : false
  end

  CLUSTERED_SVC_BINARY_PARAMS = %i[n_clusters reg_param_global max_iter_kmeans tol_kmeans].freeze

  private_constant :CLUSTERED_SVC_BINARY_PARAMS
end
|
170
|
+
end
|
171
|
+
end
|
data/lib/rumale/svm/version.rb
CHANGED
data/lib/rumale/svm.rb
CHANGED
@@ -0,0 +1,29 @@
|
|
1
|
+
# TypeProf 0.21.8
|
2
|
+
|
3
|
+
# Classes
|
4
|
+
module Rumale
  module SVM
    # Signatures for the clustered SVM classifier: k-means clusters the samples,
    # the features are expanded into a global block plus one block per cluster,
    # and a LinearSVC is trained on the expanded features.
    class ClusteredSVC
      # Constructor hyperparameters after normalization in initialize.
      @params: {n_clusters: Integer, reg_param_global: Float, max_iter_kmeans: Integer, tol_kmeans: Float, penalty: String, loss: String, dual: bool, reg_param: Float, fit_bias: bool, bias_scale: Float, tol: Float, verbose: bool, random_seed: Integer}
      # Generator seeded with @params[:random_seed]; duplicated to pick initial centroids.
      @rng: Random

      # The LinearSVC trained on the transformed features.
      attr_reader model: Rumale::SVM::LinearSVC
      # k-means centroids, shape [n_clusters, n_features].
      # NOTE(review): initialize sets this to nil; the type does not express that.
      attr_accessor cluster_centers: Numo::DFloat
      def initialize: (?n_clusters: Integer, ?reg_param_global: Float, ?max_iter_kmeans: Integer, ?tol_kmeans: Float, ?penalty: String, ?loss: String, ?dual: bool, ?reg_param: Float, ?fit_bias: bool, ?bias_scale: Float, ?tol: Float, ?verbose: bool, ?random_seed: (nil | Integer)) -> void
      # Trains on the transformed samples; returns self.
      def fit: (Numo::DFloat x, Numo::Int32 y) -> ClusteredSVC
      # Confidence scores from the underlying LinearSVC.
      def decision_function: (Numo::DFloat x) -> Numo::DFloat
      # Predicted labels from the underlying LinearSVC.
      def predict: (Numo::DFloat x) -> Numo::Int32
      # Expanded feature matrix (global block + per-cluster blocks).
      def transform: (Numo::DFloat x) -> Numo::DFloat

      private

      # NOTE(review): TypeProf-generated; the key union still lists the k-means-only keys
      # that the implementation rejects before passing the hash to LinearSVC.
      def linear_svc_params: -> (Hash[:bias_scale | :dual | :fit_bias | :loss | :max_iter_kmeans | :n_clusters | :penalty | :random_seed | :reg_param | :reg_param_global | :tol | :tol_kmeans | :verbose, Float | Integer | String | bool])
      # Runs k-means and stores the result in @cluster_centers.
      def clustering: (Numo::DFloat x) -> void
      # Index of the nearest centroid per sample.
      def assign_cluster_id: (Numo::DFloat x) -> Numo::Int32
      # Appends the bias column.
      def expand_feature: (Numo::DFloat x) -> Numo::DFloat
      def fit_bias?: -> bool

      # Hyperparameter keys used only by ClusteredSVC (not forwarded to LinearSVC).
      CLUSTERED_SVC_BINARY_PARAMS: [:n_clusters, :reg_param_global, :max_iter_kmeans, :tol_kmeans]
    end
  end
end
|
metadata
CHANGED
@@ -1,14 +1,13 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-svm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.12.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
|
-
autorequire:
|
9
8
|
bindir: exe
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-02 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
13
|
name: numo-liblinear
|
@@ -42,14 +41,14 @@ dependencies:
|
|
42
41
|
name: rumale-core
|
43
42
|
requirement: !ruby/object:Gem::Requirement
|
44
43
|
requirements:
|
45
|
-
- - "
|
44
|
+
- - ">="
|
46
45
|
- !ruby/object:Gem::Version
|
47
46
|
version: '0.24'
|
48
47
|
type: :runtime
|
49
48
|
prerelease: false
|
50
49
|
version_requirements: !ruby/object:Gem::Requirement
|
51
50
|
requirements:
|
52
|
-
- - "
|
51
|
+
- - ">="
|
53
52
|
- !ruby/object:Gem::Version
|
54
53
|
version: '0.24'
|
55
54
|
description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
|
@@ -66,6 +65,7 @@ files:
|
|
66
65
|
- LICENSE.txt
|
67
66
|
- README.md
|
68
67
|
- lib/rumale/svm.rb
|
68
|
+
- lib/rumale/svm/clustered_svc.rb
|
69
69
|
- lib/rumale/svm/linear_one_class_svm.rb
|
70
70
|
- lib/rumale/svm/linear_svc.rb
|
71
71
|
- lib/rumale/svm/linear_svr.rb
|
@@ -79,6 +79,7 @@ files:
|
|
79
79
|
- lib/rumale/svm/svr.rb
|
80
80
|
- lib/rumale/svm/version.rb
|
81
81
|
- sig/rumale/svm.rbs
|
82
|
+
- sig/rumale/svm/clustered_svc.rbs
|
82
83
|
- sig/rumale/svm/linear_one_class_svm.rbs
|
83
84
|
- sig/rumale/svm/linear_svc.rbs
|
84
85
|
- sig/rumale/svm/linear_svr.rbs
|
@@ -99,7 +100,6 @@ metadata:
|
|
99
100
|
changelog_uri: https://github.com/yoshoku/rumale-svm/blob/main/CHANGELOG.md
|
100
101
|
documentation_uri: https://yoshoku.github.io/rumale-svm/doc/
|
101
102
|
rubygems_mfa_required: 'true'
|
102
|
-
post_install_message:
|
103
103
|
rdoc_options: []
|
104
104
|
require_paths:
|
105
105
|
- lib
|
@@ -114,8 +114,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
114
114
|
- !ruby/object:Gem::Version
|
115
115
|
version: '0'
|
116
116
|
requirements: []
|
117
|
-
rubygems_version: 3.
|
118
|
-
signing_key:
|
117
|
+
rubygems_version: 3.6.2
|
119
118
|
specification_version: 4
|
120
119
|
summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
|
121
120
|
with Rumale interface.
|