rumale-svm 0.10.0 → 0.12.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d027826bfde557a8e724b62252182549d71a3bc90ebcfa7b3e5f4a4f915a553e
4
- data.tar.gz: ba0baf950d204dcbf993b41c99ca8a79a63b0ffac87fe5204211c7b56759374e
3
+ metadata.gz: eb43a42ed2be607317dde53e4f81788309aa27e28d1c169419e6e0d0cce073af
4
+ data.tar.gz: fe1ba15cd33b364a73b249e83c1b4a6937f50b5d7424db49957c2b4ca87c9267
5
5
  SHA512:
6
- metadata.gz: 65f9a78800033bdbc0354d146cf6150d35aa4924c07164a7bfe578704642d4cde0d49a604fcf5dd75a135462282ee2e18fbcf9157773ce8d827c5a671be6eaa6
7
- data.tar.gz: c3d8da2ad2790c8cc656194c1dd0a083a5dbc364ef3b14c768cf5edcb83449d4ef0d61f3d00d8b358171c492e09f2053bd73c7e3b56418234d996fd70945c23b
6
+ metadata.gz: 4aad8ec437d6522fbd3d90167a80a463d11b19ff60ee00d90d4a4f8810ee25a615592f2d5caeaf22fe9504aad9189f87c332626c27ef1f8de0f70c15b1908b51
7
+ data.tar.gz: d2b12aab3a83ba1ad71d89e49499b007f3054dbc0e49402b4a00ffc238450ade1701edb46bbd7d276dd231cd471831d2db62c210e4b76a096dc02a072c202796
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ # [[0.12.0](https://github.com/yoshoku/rumale-svm/compare/v0.11.0...v0.12.0)]
2
+ - Fix the version specification of the rumale-core gem dependency.
3
+
4
+ # [[0.11.0](https://github.com/yoshoku/rumale-svm/compare/v0.10.0...v0.11.0)]
5
+ - Add Rumale::SVM::ClusteredSVC, a classifier based on the clustered support vector machine.
6
+
1
7
  # 0.10.0
2
8
  - Add Rumale::SVM::RandomRecursiveSVC that is classifier with random recursive support vector machine.
3
9
  - Add type declaration files for RandomRecursiveSVC and LocallyLinearSVC.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2019-2023 Atsushi Tatsuma
1
+ Copyright (c) 2019-2025 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
@@ -0,0 +1,171 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/classifier'
5
+ require 'rumale/pairwise_metric'
6
+ require 'rumale/validation'
7
+ require 'rumale/svm/linear_svc'
8
+
9
module Rumale
  module SVM
    # ClusteredSVC is a class that implements Clustered Support Vector Classifier.
    # Samples are mapped into a block-structured feature space consisting of one global block
    # and one block per k-means cluster, and a single LinearSVC is trained on the mapped samples.
    #
    # @example
    #   require 'rumale/svm'
    #
    #   estimator = Rumale::SVM::ClusteredSVC.new(n_clusters: 16, reg_param_global: 1.0, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Gu, Q., and Han, J., "Clustered Support Vector Machines," In Proc. AISTATS'13, pp. 307--315, 2013.
    class ClusteredSVC < Rumale::Base::Estimator
      include Rumale::Base::Classifier

      # Return the classifier.
      # @return [LinearSVC]
      attr_reader :model

      # Return the centroids.
      # While this is nil, {#transform} (and therefore {#fit}) computes it with k-means on the given data.
      # Once it is set, either by an earlier call or by assignment, the stored centroids are reused as-is.
      # @return [Numo::DFloat] (shape: [n_clusters, n_features])
      attr_accessor :cluster_centers

      # Create a new classifier with Clustered Support Vector Machine.
      #
      # @param n_clusters [Integer] The number of clusters.
      # @param reg_param_global [Float] The regularization parameter for global reference vector.
      # @param max_iter_kmeans [Integer] The maximum number of iterations for k-means clustering.
      # @param tol_kmeans [Float] The tolerance of termination criterion for k-means clustering.
      # @param penalty [String] The type of norm used in the penalization ('l2' or 'l1').
      # @param loss [String] The type of loss function ('squared_hinge' or 'hinge').
      #   This parameter is ignored if penalty = 'l1'.
      # @param dual [Boolean] The flag indicating whether to solve dual optimization problem.
      #   When n_samples > n_features, dual = false is more preferable.
      #   This parameter is ignored if loss = 'hinge'.
      # @param reg_param [Float] The regularization parameter.
      # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
      # @param bias_scale [Float] The scale of the bias term.
      #   This parameter is ignored if fit_bias = false.
      # @param tol [Float] The tolerance of termination criterion.
      # @param verbose [Boolean] The flag indicating whether to output learning process message
      # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
      def initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, # rubocop:disable Metrics/ParameterLists
                     penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0,
                     fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil)
        super()
        @params = {
          n_clusters: n_clusters,
          reg_param_global: reg_param_global,
          max_iter_kmeans: max_iter_kmeans,
          tol_kmeans: tol_kmeans,
          penalty: penalty == 'l1' ? 'l1' : 'l2',
          loss: loss == 'hinge' ? 'hinge' : 'squared_hinge',
          dual: dual,
          reg_param: reg_param.to_f,
          fit_bias: fit_bias,
          bias_scale: bias_scale.to_f,
          tol: tol.to_f,
          verbose: verbose,
          random_seed: random_seed || Random.rand(4_294_967_295)
        }
        @rng = Random.new(@params[:random_seed])
        @cluster_centers = nil
      end

      # Fit the model with given training data.
      #
      # NOTE: centroids are only computed while {#cluster_centers} is nil, so calling fit again
      # reuses the centroids of the previous fit. Set cluster_centers to nil first to re-cluster.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [ClusteredSVC] The learned classifier itself.
      def fit(x, y)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        y = ::Rumale::Validation.check_convert_label_array(y)
        ::Rumale::Validation.check_sample_size(x, y)

        z = transform(x)
        @model = LinearSVC.new(**linear_svc_params).fit(z, y)
        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
      def decision_function(x)
        z = transform(x)
        @model.decision_function(z)
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        z = transform(x)
        @model.predict(z)
      end

      # Transform the given data with the learned model.
      #
      # The first block of columns holds every sample scaled by 1 / sqrt(reg_param_global).
      # Block n + 1 holds the sample only if it is assigned to cluster n, and zeros otherwise.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, m * (1 + n_clusters)]) The transformed data,
      #   where m = n_features + 1 if fit_bias is true and m = n_features otherwise.
      def transform(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        clustering(x) if @cluster_centers.nil?

        # Cluster ids are computed on the raw features, before the bias column is appended.
        cluster_ids = assign_cluster_id(x)

        x = expand_feature(x) if fit_bias?

        n_samples, n_features = x.shape
        z = Numo::DFloat.zeros(n_samples, n_features * (1 + @params[:n_clusters]))
        z[true, 0...n_features] = 1.fdiv(Math.sqrt(@params[:reg_param_global])) * x
        @params[:n_clusters].times do |n|
          assigned_bits = cluster_ids.eq(n)
          # A cluster can receive no samples (e.g. when transforming a small batch); its block stays zero.
          next unless assigned_bits.count.positive?

          sample_ids = assigned_bits.where
          offset = n_features * (n + 1)
          z[sample_ids, offset...(offset + n_features)] = x[sample_ids, true]
        end

        z
      end

      private

      # Parameters forwarded to LinearSVC: the ClusteredSVC-only keys are dropped, and fit_bias is
      # forced to false because the bias column (when enabled) is already appended in #transform.
      def linear_svc_params
        @params.reject { |key, _| CLUSTERED_SVC_BINARY_PARAMS.include?(key) }.merge(fit_bias: false)
      end

      # Run k-means on x and store the result in @cluster_centers.
      # Initial centroids are n_clusters rows of x drawn independently (duplicates are possible).
      # Iteration stops after max_iter_kmeans rounds or once the mean centroid shift is <= tol_kmeans.
      def clustering(x)
        n_samples = x.shape[0]
        # Draw from a copy so that the estimator's own generator state is not advanced.
        sub_rng = @rng.dup
        rand_id = Array.new(@params[:n_clusters]) { |_v| sub_rng.rand(0...n_samples) }
        @cluster_centers = x[rand_id, true].dup

        @params[:max_iter_kmeans].times do |_t|
          center_ids = assign_cluster_id(x)
          old_centers = @cluster_centers.dup
          @params[:n_clusters].times do |n|
            assigned_bits = center_ids.eq(n)
            # A centroid that received no samples keeps its previous position.
            @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
          end
          error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
          break if error <= @params[:tol_kmeans]
        end
      end

      # Return the index of the nearest centroid (Euclidean distance) for each row of x.
      def assign_cluster_id(x)
        distance_matrix = ::Rumale::PairwiseMetric.euclidean_distance(x, @cluster_centers)
        # min_index(axis: 1) yields flat indices into distance_matrix; subtracting each row's start
        # offset (0, k, 2k, ... with k = n_clusters) converts them into per-row column (cluster) indices.
        distance_matrix.min_index(axis: 1) - Numo::Int32[*0.step(distance_matrix.size - 1, @cluster_centers.shape[0])]
      end

      # Append a constant column of value bias_scale to x.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
      end

      # Whether a bias column is appended in #transform (nil and false disable it).
      def fit_bias?
        return false if @params[:fit_bias].nil? || @params[:fit_bias] == false

        true
      end

      # Hyperparameters used only by ClusteredSVC itself; they are not forwarded to LinearSVC.
      CLUSTERED_SVC_BINARY_PARAMS = %i[n_clusters reg_param_global max_iter_kmeans tol_kmeans].freeze

      private_constant :CLUSTERED_SVC_BINARY_PARAMS
    end
  end
end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of Rumale interfaces for support vector machine algorithms using LIBSVM and LIBLINEAR.
6
6
  module SVM
7
7
  # The version of Rumale::SVM you are using.
8
- VERSION = '0.10.0'
8
+ VERSION = '0.12.0'
9
9
  end
10
10
  end
data/lib/rumale/svm.rb CHANGED
@@ -12,3 +12,4 @@ require 'rumale/svm/logistic_regression'
12
12
  require 'rumale/svm/linear_one_class_svm'
13
13
  require 'rumale/svm/locally_linear_svc'
14
14
  require 'rumale/svm/random_recursive_svc'
15
+ require 'rumale/svm/clustered_svc'
@@ -0,0 +1,29 @@
1
+ # TypeProf 0.21.8
2
+
3
+ # Classes
4
module Rumale
  module SVM
    class ClusteredSVC
      @params: {n_clusters: Integer, reg_param_global: Float, max_iter_kmeans: Integer, tol_kmeans: Float, penalty: String, loss: String, dual: bool, reg_param: Float, fit_bias: bool, bias_scale: Float, tol: Float, verbose: bool, random_seed: Integer}
      @rng: Random

      attr_reader model: Rumale::SVM::LinearSVC
      attr_accessor cluster_centers: Numo::DFloat
      def initialize: (?n_clusters: Integer, ?reg_param_global: Float, ?max_iter_kmeans: Integer, ?tol_kmeans: Float, ?penalty: String, ?loss: String, ?dual: bool, ?reg_param: Float, ?fit_bias: bool, ?bias_scale: Float, ?tol: Float, ?verbose: bool, ?random_seed: (nil | Integer)) -> void
      def fit: (Numo::DFloat x, Numo::Int32 y) -> ClusteredSVC
      def decision_function: (Numo::DFloat x) -> Numo::DFloat
      def predict: (Numo::DFloat x) -> Numo::Int32
      def transform: (Numo::DFloat x) -> Numo::DFloat

      private

      # The keys in CLUSTERED_SVC_BINARY_PARAMS are removed before the hash is passed to LinearSVC,
      # so they are not part of the returned key set.
      def linear_svc_params: -> (Hash[:bias_scale | :dual | :fit_bias | :loss | :penalty | :random_seed | :reg_param | :tol | :verbose, Float | Integer | String | bool])
      def clustering: (Numo::DFloat x) -> void
      def assign_cluster_id: (Numo::DFloat x) -> Numo::Int32
      def expand_feature: (Numo::DFloat x) -> Numo::DFloat
      def fit_bias?: -> bool

      CLUSTERED_SVC_BINARY_PARAMS: [:n_clusters, :reg_param_global, :max_iter_kmeans, :tol_kmeans]
    end
  end
end
metadata CHANGED
@@ -1,14 +1,13 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-svm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.0
4
+ version: 0.12.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
9
8
  bindir: exe
10
9
  cert_chain: []
11
- date: 2023-12-02 00:00:00.000000000 Z
10
+ date: 2025-01-02 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
14
13
  name: numo-liblinear
@@ -42,14 +41,14 @@ dependencies:
42
41
  name: rumale-core
43
42
  requirement: !ruby/object:Gem::Requirement
44
43
  requirements:
45
- - - "~>"
44
+ - - ">="
46
45
  - !ruby/object:Gem::Version
47
46
  version: '0.24'
48
47
  type: :runtime
49
48
  prerelease: false
50
49
  version_requirements: !ruby/object:Gem::Requirement
51
50
  requirements:
52
- - - "~>"
51
+ - - ">="
53
52
  - !ruby/object:Gem::Version
54
53
  version: '0.24'
55
54
  description: 'Rumale::SVM provides support vector machine algorithms using LIBSVM
@@ -66,6 +65,7 @@ files:
66
65
  - LICENSE.txt
67
66
  - README.md
68
67
  - lib/rumale/svm.rb
68
+ - lib/rumale/svm/clustered_svc.rb
69
69
  - lib/rumale/svm/linear_one_class_svm.rb
70
70
  - lib/rumale/svm/linear_svc.rb
71
71
  - lib/rumale/svm/linear_svr.rb
@@ -79,6 +79,7 @@ files:
79
79
  - lib/rumale/svm/svr.rb
80
80
  - lib/rumale/svm/version.rb
81
81
  - sig/rumale/svm.rbs
82
+ - sig/rumale/svm/clustered_svc.rbs
82
83
  - sig/rumale/svm/linear_one_class_svm.rbs
83
84
  - sig/rumale/svm/linear_svc.rbs
84
85
  - sig/rumale/svm/linear_svr.rbs
@@ -99,7 +100,6 @@ metadata:
99
100
  changelog_uri: https://github.com/yoshoku/rumale-svm/blob/main/CHANGELOG.md
100
101
  documentation_uri: https://yoshoku.github.io/rumale-svm/doc/
101
102
  rubygems_mfa_required: 'true'
102
- post_install_message:
103
103
  rdoc_options: []
104
104
  require_paths:
105
105
  - lib
@@ -114,8 +114,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
114
114
  - !ruby/object:Gem::Version
115
115
  version: '0'
116
116
  requirements: []
117
- rubygems_version: 3.4.22
118
- signing_key:
117
+ rubygems_version: 3.6.2
119
118
  specification_version: 4
120
119
  summary: Rumale::SVM provides support vector machine algorithms using LIBSVM and LIBLINEAR
121
120
  with Rumale interface.