rumale-svm 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,160 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/libsvm'
4
+ require 'rumale/base/base_estimator'
5
+ require 'rumale/base/regressor'
6
+
7
module Rumale
  module SVM
    # SVR is a class that provides Kernel Epsilon-Support Vector Regressor in LIBSVM with Rumale interface.
    #
    # @example
    #   estimator = Rumale::SVM::SVR.new(reg_param: 1.0, kernel: 'rbf', gamma: 10.0, random_seed: 1)
    #   estimator.fit(training_samples, training_target_values)
    #   results = estimator.predict(testing_samples)
    class SVR
      include Base::BaseEstimator
      include Base::Regressor

      # Create a new regressor with Kernel Epsilon-Support Vector Regressor.
      #
      # @param reg_param [Float] The regularization parameter.
      # @param epsilon [Float] The epsilon parameter in the loss function of epsilon-SVR.
      # @param kernel [String] The type of kernel function ('rbf', 'linear', 'poly', 'sigmoid', and 'precomputed').
      #   Any other value is treated as 'rbf'.
      # @param degree [Integer] The degree parameter in polynomial kernel function.
      # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
      # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
      # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
      # @param cache_size [Float] The cache memory size in MB.
      # @param tol [Float] The tolerance of termination criterion.
      # @param verbose [Boolean] The flag indicating whether to output learning process messages.
      # @param random_seed [Integer/Nil] The seed value used to initialize the random generator.
      def initialize(reg_param: 1.0, epsilon: 0.1, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
                     shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
        check_params_float(reg_param: reg_param, epsilon: epsilon, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
        check_params_integer(degree: degree)
        check_params_boolean(shrinking: shrinking, verbose: verbose)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        # Rumale-style parameter names; translated to LIBSVM names in #libsvm_params.
        @params = {
          reg_param: reg_param,
          epsilon: epsilon,
          kernel: kernel,
          degree: degree,
          gamma: gamma,
          coef0: coef0,
          shrinking: shrinking,
          cache_size: cache_size,
          tol: tol,
          verbose: verbose,
          random_seed: random_seed
        }
        @model = nil
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   If the kernel is 'precomputed', x must be a square kernel (Gram) matrix (shape: [n_samples, n_samples]).
      # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
      # @return [SVR] The learned regressor itself.
      def fit(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        # With a precomputed kernel, LIBSVM expects a 1-based sample serial number in the first column.
        xx = precomputed_kernel? ? add_index_col(x) : x
        @model = Numo::Libsvm.train(xx, y, libsvm_params)
        self
      end

      # Predict values for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      #   If the kernel is 'precomputed', x must hold the kernel values between the samples and the
      #   training samples (shape: [n_samples, n_training_samples]).
      # @return [Numo::DFloat] (shape: [n_samples]) Predicted value per sample.
      def predict(x)
        check_sample_array(x)
        xx = precomputed_kernel? ? add_index_col(x) : x
        Numo::Libsvm.predict(xx, libsvm_params, @model)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about SVR.
      def marshal_dump
        { params: @params,
          model: @model }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @model = obj[:model]
        nil
      end

      # Return the indices of support vectors.
      # @return [Numo::Int32] (shape: [n_support_vectors])
      def support
        @model[:sv_indices]
      end

      # Return the support vectors.
      # With a precomputed kernel, the index column added during fitting is removed.
      # @return [Numo::DFloat] (shape: [n_support_vectors, n_features])
      def support_vectors
        precomputed_kernel? ? del_index_col(@model[:SV]) : @model[:SV]
      end

      # Return the number of support vectors.
      # @return [Integer]
      def n_support
        support.size
      end

      # Return the coefficients of the support vectors in the decision function.
      # @return [Numo::DFloat] (shape: [1, n_support_vectors])
      def dual_coef
        @model[:sv_coef]
      end

      # @deprecated Misspelled name kept for backward compatibility; use {#dual_coef} instead.
      alias duel_coef dual_coef

      # Return the intercepts in decision function.
      # NOTE(review): this is LIBSVM's raw rho value. LIBSVM's decision function subtracts rho,
      # so the additive intercept may be its negation; confirm the sign before relying on it.
      # @return [Numo::DFloat] (shape: [1])
      def intercept
        @model[:rho]
      end

      private

      # Prepend a 1-based serial-number column, as LIBSVM requires for precomputed kernels.
      def add_index_col(x)
        idx = Numo::Int32.new(x.shape[0]).seq + 1
        Numo::NArray.hstack([idx.expand_dims(1), x])
      end

      # Drop the serial-number column added by #add_index_col.
      def del_index_col(x)
        x[true, 1..-1].dup
      end

      def precomputed_kernel?
        @params[:kernel] == 'precomputed'
      end

      # Translate Rumale-style parameters into the hash expected by Numo::Libsvm.
      def libsvm_params
        res = @params.merge(svm_type: Numo::Libsvm::SvmType::EPSILON_SVR)
        res[:kernel_type] = case res.delete(:kernel)
                            when 'linear'
                              Numo::Libsvm::KernelType::LINEAR
                            when 'poly'
                              Numo::Libsvm::KernelType::POLY
                            when 'sigmoid'
                              Numo::Libsvm::KernelType::SIGMOID
                            when 'precomputed'
                              Numo::Libsvm::KernelType::PRECOMPUTED
                            else
                              # 'rbf' and any unrecognized kernel name end up here.
                              Numo::Libsvm::KernelType::RBF
                            end
        res[:C] = res.delete(:reg_param)
        res[:p] = res.delete(:epsilon)
        res[:eps] = res.delete(:tol)
        res
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
+ # frozen_string_literal: true
2
+
3
# Rumale is a machine learning library in Ruby.
module Rumale
  # This module consists of Rumale interfaces for support vector machine algorithms with LIBSVM and LIBLINEAR.
  module SVM
    # The version of Rumale-SVM you are using.
    VERSION = '0.1.0'
  end
end
@@ -0,0 +1,40 @@
1
# Gem specification for rumale-svm.
#
# Put lib/ on the load path so the version constant can be read from the source tree.
lib = File.expand_path('lib', __dir__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'rumale/svm/version'

Gem::Specification.new do |spec|
  spec.name          = 'rumale-svm'
  spec.version       = Rumale::SVM::VERSION
  spec.authors       = ['yoshoku']
  spec.email         = ['yoshoku@outlook.com']

  spec.summary = <<~MSG
    Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR with Rumale interface.
  MSG
  spec.description = <<~MSG
    Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR with Rumale interface.
  MSG
  spec.homepage = 'https://github.com/yoshoku/rumale-svm'

  spec.metadata['homepage_uri'] = spec.homepage
  spec.metadata['source_code_uri'] = 'https://github.com/yoshoku/rumale-svm'
  spec.metadata['changelog_uri'] = 'https://github.com/yoshoku/rumale-svm/blob/master/CHANGELOG.md'
  spec.metadata['documentation_uri'] = 'https://yoshoku.github.io/rumale-svm/doc/'

  # Specify which files should be added to the gem when it is released.
  # `git ls-files -z` lists the files tracked by git; test/spec/features directories are excluded.
  spec.files = Dir.chdir(File.expand_path(__dir__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
  end
  spec.bindir        = 'exe'
  spec.executables   = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ['lib']

  spec.add_runtime_dependency 'numo-liblinear', '~> 1.0'
  spec.add_runtime_dependency 'numo-libsvm', '~> 1.0'
  # NOTE(review): unconstrained, yet lib/rumale/svm/svr.rb relies on Rumale validation helpers
  # such as check_params_type_or_nil; consider pinning the minimum Rumale version that provides them.
  spec.add_runtime_dependency 'rumale'
  spec.add_development_dependency 'bundler', '~> 2.0'
  spec.add_development_dependency 'coveralls', '~> 0.8'
  # Rake < 12.3.3 is affected by CVE-2020-8130 (command injection through Rake::FileList).
  spec.add_development_dependency 'rake', '>= 12.3.3', '< 14.0'
  spec.add_development_dependency 'rspec', '~> 3.0'
end
metadata ADDED
@@ -0,0 +1,171 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: rumale-svm
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - yoshoku
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2019-10-22 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-liblinear
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: numo-libsvm
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '1.0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '1.0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rumale
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: bundler
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - "~>"
60
+ - !ruby/object:Gem::Version
61
+ version: '2.0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - "~>"
67
+ - !ruby/object:Gem::Version
68
+ version: '2.0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: coveralls
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - "~>"
74
+ - !ruby/object:Gem::Version
75
+ version: '0.8'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - "~>"
81
+ - !ruby/object:Gem::Version
82
+ version: '0.8'
83
+ - !ruby/object:Gem::Dependency
84
+ name: rake
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - "~>"
88
+ - !ruby/object:Gem::Version
89
+ version: '10.0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - "~>"
95
+ - !ruby/object:Gem::Version
96
+ version: '10.0'
97
+ - !ruby/object:Gem::Dependency
98
+ name: rspec
99
+ requirement: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - "~>"
102
+ - !ruby/object:Gem::Version
103
+ version: '3.0'
104
+ type: :development
105
+ prerelease: false
106
+ version_requirements: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - "~>"
109
+ - !ruby/object:Gem::Version
110
+ version: '3.0'
111
+ description: 'Rumale-SVM provides support vector machine algorithms of LIBSVM and
112
+ LIBLINEAR with Rumale interface.
113
+
114
+ '
115
+ email:
116
+ - yoshoku@outlook.com
117
+ executables: []
118
+ extensions: []
119
+ extra_rdoc_files: []
120
+ files:
121
+ - ".coveralls.yml"
122
+ - ".gitignore"
123
+ - ".rspec"
124
+ - ".travis.yml"
125
+ - CODE_OF_CONDUCT.md
126
+ - Gemfile
127
+ - LICENSE.txt
128
+ - README.md
129
+ - Rakefile
130
+ - bin/console
131
+ - bin/setup
132
+ - lib/rumale/svm.rb
133
+ - lib/rumale/svm/linear_svc.rb
134
+ - lib/rumale/svm/linear_svr.rb
135
+ - lib/rumale/svm/logistic_regression.rb
136
+ - lib/rumale/svm/nu_svc.rb
137
+ - lib/rumale/svm/nu_svr.rb
138
+ - lib/rumale/svm/one_class_svm.rb
139
+ - lib/rumale/svm/svc.rb
140
+ - lib/rumale/svm/svr.rb
141
+ - lib/rumale/svm/version.rb
142
+ - rumale-svm.gemspec
143
+ homepage: https://github.com/yoshoku/rumale-svm
144
+ licenses: []
145
+ metadata:
146
+ homepage_uri: https://github.com/yoshoku/rumale-svm
147
+ source_code_uri: https://github.com/yoshoku/rumale-svm
148
+ changelog_uri: https://github.com/yoshoku/rumale-svm/blob/master/CHANGELOG.md
149
+ documentation_uri: https://yoshoku.github.io/rumale-svm/doc/
150
+ post_install_message:
151
+ rdoc_options: []
152
+ require_paths:
153
+ - lib
154
+ required_ruby_version: !ruby/object:Gem::Requirement
155
+ requirements:
156
+ - - ">="
157
+ - !ruby/object:Gem::Version
158
+ version: '0'
159
+ required_rubygems_version: !ruby/object:Gem::Requirement
160
+ requirements:
161
+ - - ">="
162
+ - !ruby/object:Gem::Version
163
+ version: '0'
164
+ requirements: []
165
+ rubyforge_project:
166
+ rubygems_version: 2.6.14.4
167
+ signing_key:
168
+ specification_version: 4
169
+ summary: Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
170
+ with Rumale interface.
171
+ test_files: []