rumale-svm 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,160 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/libsvm'
4
+ require 'rumale/base/base_estimator'
5
+ require 'rumale/base/regressor'
6
+
7
module Rumale
  module SVM
    # SVR is a class that provides Kernel Epsilon-Support Vector Regressor in LIBSVM with Rumale interface.
    #
    # @example
    #   estimator = Rumale::SVM::SVR.new(reg_param: 1.0, kernel: 'rbf', gamma: 10.0, random_seed: 1)
    #   estimator.fit(training_samples, training_target_values)
    #   results = estimator.predict(testing_samples)
    class SVR
      include Base::BaseEstimator
      include Base::Regressor

      # Create a new regressor with Kernel Epsilon-Support Vector Regressor.
      #
      # @param reg_param [Float] The regularization parameter.
      # @param epsilon [Float] The epsilon parameter in loss function of epsilon-svr.
      # @param kernel [String] The type of kernel function ('rbf', 'linear', 'poly', 'sigmoid', and 'precomputed').
      #   Any other value is treated as 'rbf'.
      # @param degree [Integer] The degree parameter in polynomial kernel function.
      # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
      # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
      # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
      # @param cache_size [Float] The cache memory size in MB.
      # @param tol [Float] The tolerance of termination criterion.
      # @param verbose [Boolean] The flag indicating whether to output learning process message.
      # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
      def initialize(reg_param: 1.0, epsilon: 0.1, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
                     shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
        check_params_float(reg_param: reg_param, epsilon: epsilon, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
        check_params_integer(degree: degree)
        check_params_boolean(shrinking: shrinking, verbose: verbose)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        @params = {}
        @params[:reg_param] = reg_param
        @params[:epsilon] = epsilon
        @params[:kernel] = kernel
        @params[:degree] = degree
        @params[:gamma] = gamma
        @params[:coef0] = coef0
        @params[:shrinking] = shrinking
        @params[:cache_size] = cache_size
        @params[:tol] = tol
        @params[:verbose] = verbose
        @params[:random_seed] = random_seed
        @model = nil
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   If the kernel is 'precomputed', x must be a square kernel (Gram) matrix (shape: [n_samples, n_samples]).
      # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
      # @return [SVR] The learned regressor itself.
      def fit(x, y)
        check_sample_array(x)
        check_tvalue_array(y)
        check_sample_tvalue_size(x, y)
        # LIBSVM's precomputed-kernel format expects a leading 1-based sample-index column.
        xx = precomputed_kernel? ? add_index_col(x) : x
        @model = Numo::Libsvm.train(xx, y, libsvm_params)
        self
      end

      # Predict values for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      #   If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
      # @return [Numo::DFloat] (shape: [n_samples]) Predicted value per sample.
      def predict(x)
        check_sample_array(x)
        xx = precomputed_kernel? ? add_index_col(x) : x
        Numo::Libsvm.predict(xx, libsvm_params, @model)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about SVR.
      def marshal_dump
        { params: @params,
          model: @model }
      end

      # Load marshal data.
      # @param obj [Hash] The marshal data produced by {#marshal_dump}.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @model = obj[:model]
        nil
      end

      # Return the indices of support vectors.
      # @return [Numo::Int32] (shape: [n_support_vectors])
      def support
        @model[:sv_indices]
      end

      # Return the support_vectors.
      # For the 'precomputed' kernel, the leading sample-index column added for LIBSVM is removed.
      # @return [Numo::DFloat] (shape: [n_support_vectors, n_features])
      def support_vectors
        precomputed_kernel? ? del_index_col(@model[:SV]) : @model[:SV]
      end

      # Return the number of support vectors.
      # @return [Integer]
      def n_support
        support.size
      end

      # Return the coefficients of the support vectors in the decision function.
      # @return [Numo::DFloat] (shape: [1, n_support_vectors])
      def dual_coef
        @model[:sv_coef]
      end

      # @deprecated Misspelled name kept for backward compatibility; use {#dual_coef} instead.
      alias duel_coef dual_coef

      # Return the intercepts in decision function.
      # NOTE: this is the raw `rho` value stored in the LIBSVM model, returned as-is (not negated).
      # @return [Numo::DFloat] (shape: [1])
      def intercept
        @model[:rho]
      end

      private

      # Prepend a 1-based sample-index column, as required by LIBSVM for precomputed kernels.
      def add_index_col(x)
        idx = Numo::Int32.new(x.shape[0]).seq + 1
        Numo::NArray.hstack([idx.expand_dims(1), x])
      end

      # Drop the leading sample-index column added by add_index_col.
      def del_index_col(x)
        x[true, 1..-1].dup
      end

      def precomputed_kernel?
        @params[:kernel] == 'precomputed'
      end

      # Translate the Rumale-style parameter hash into the key names numo-libsvm expects.
      def libsvm_params
        res = @params.merge(svm_type: Numo::Libsvm::SvmType::EPSILON_SVR)
        res[:kernel_type] = case res.delete(:kernel)
                            when 'linear'
                              Numo::Libsvm::KernelType::LINEAR
                            when 'poly'
                              Numo::Libsvm::KernelType::POLY
                            when 'sigmoid'
                              Numo::Libsvm::KernelType::SIGMOID
                            when 'precomputed'
                              Numo::Libsvm::KernelType::PRECOMPUTED
                            else
                              # Unknown kernel names silently fall back to RBF.
                              Numo::Libsvm::KernelType::RBF
                            end
        res[:C] = res.delete(:reg_param)
        res[:p] = res.delete(:epsilon)
        res[:eps] = res.delete(:tol)
        res
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
+ # frozen_string_literal: true
2
+
3
# Rumale is a machine learning library in Ruby.
module Rumale
  # This module consists of Rumale interfaces for support vector machine algorithms with LIBSVM and LIBLINEAR.
  module SVM
    # The version of Rumale-SVM you are using.
    VERSION = '0.1.0'
  end
end
@@ -0,0 +1,40 @@
1
lib = File.expand_path('lib', __dir__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'rumale/svm/version'

Gem::Specification.new do |spec|
  spec.name = 'rumale-svm'
  spec.version = Rumale::SVM::VERSION
  spec.authors = ['yoshoku']
  spec.email = ['yoshoku@outlook.com']

  spec.summary = <<~MSG
    Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR with Rumale interface.
  MSG
  spec.description = <<~MSG
    Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR with Rumale interface.
  MSG
  spec.homepage = 'https://github.com/yoshoku/rumale-svm'

  spec.metadata['homepage_uri'] = spec.homepage
  spec.metadata['source_code_uri'] = 'https://github.com/yoshoku/rumale-svm'
  spec.metadata['changelog_uri'] = 'https://github.com/yoshoku/rumale-svm/blob/master/CHANGELOG.md'
  spec.metadata['documentation_uri'] = 'https://yoshoku.github.io/rumale-svm/doc/'

  # Specify which files should be added to the gem when it is released.
  # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
  spec.files = Dir.chdir(File.expand_path(__dir__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
  end
  spec.bindir = 'exe'
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ['lib']

  spec.add_runtime_dependency 'numo-liblinear', '~> 1.0'
  spec.add_runtime_dependency 'numo-libsvm', '~> 1.0'
  spec.add_runtime_dependency 'rumale'
  spec.add_development_dependency 'bundler', '~> 2.0'
  spec.add_development_dependency 'coveralls', '~> 0.8'
  # Rake < 12.3.3 is affected by CVE-2020-8130 (command injection via Rake::FileList).
  spec.add_development_dependency 'rake', '>= 12.3.3'
  spec.add_development_dependency 'rspec', '~> 3.0'
end
metadata ADDED
@@ -0,0 +1,171 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: rumale-svm
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - yoshoku
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2019-10-22 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-liblinear
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: numo-libsvm
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '1.0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '1.0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rumale
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: bundler
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - "~>"
60
+ - !ruby/object:Gem::Version
61
+ version: '2.0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - "~>"
67
+ - !ruby/object:Gem::Version
68
+ version: '2.0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: coveralls
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - "~>"
74
+ - !ruby/object:Gem::Version
75
+ version: '0.8'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - "~>"
81
+ - !ruby/object:Gem::Version
82
+ version: '0.8'
83
+ - !ruby/object:Gem::Dependency
84
+ name: rake
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - "~>"
88
+ - !ruby/object:Gem::Version
89
+ version: '10.0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - "~>"
95
+ - !ruby/object:Gem::Version
96
+ version: '10.0'
97
+ - !ruby/object:Gem::Dependency
98
+ name: rspec
99
+ requirement: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - "~>"
102
+ - !ruby/object:Gem::Version
103
+ version: '3.0'
104
+ type: :development
105
+ prerelease: false
106
+ version_requirements: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - "~>"
109
+ - !ruby/object:Gem::Version
110
+ version: '3.0'
111
+ description: 'Rumale-SVM provides support vector machine algorithms of LIBSVM and
112
+ LIBLINEAR with Rumale interface.
113
+
114
+ '
115
+ email:
116
+ - yoshoku@outlook.com
117
+ executables: []
118
+ extensions: []
119
+ extra_rdoc_files: []
120
+ files:
121
+ - ".coveralls.yml"
122
+ - ".gitignore"
123
+ - ".rspec"
124
+ - ".travis.yml"
125
+ - CODE_OF_CONDUCT.md
126
+ - Gemfile
127
+ - LICENSE.txt
128
+ - README.md
129
+ - Rakefile
130
+ - bin/console
131
+ - bin/setup
132
+ - lib/rumale/svm.rb
133
+ - lib/rumale/svm/linear_svc.rb
134
+ - lib/rumale/svm/linear_svr.rb
135
+ - lib/rumale/svm/logistic_regression.rb
136
+ - lib/rumale/svm/nu_svc.rb
137
+ - lib/rumale/svm/nu_svr.rb
138
+ - lib/rumale/svm/one_class_svm.rb
139
+ - lib/rumale/svm/svc.rb
140
+ - lib/rumale/svm/svr.rb
141
+ - lib/rumale/svm/version.rb
142
+ - rumale-svm.gemspec
143
+ homepage: https://github.com/yoshoku/rumale-svm
144
+ licenses: []
145
+ metadata:
146
+ homepage_uri: https://github.com/yoshoku/rumale-svm
147
+ source_code_uri: https://github.com/yoshoku/rumale-svm
148
+ changelog_uri: https://github.com/yoshoku/rumale-svm/blob/master/CHANGELOG.md
149
+ documentation_uri: https://yoshoku.github.io/rumale-svm/doc/
150
+ post_install_message:
151
+ rdoc_options: []
152
+ require_paths:
153
+ - lib
154
+ required_ruby_version: !ruby/object:Gem::Requirement
155
+ requirements:
156
+ - - ">="
157
+ - !ruby/object:Gem::Version
158
+ version: '0'
159
+ required_rubygems_version: !ruby/object:Gem::Requirement
160
+ requirements:
161
+ - - ">="
162
+ - !ruby/object:Gem::Version
163
+ version: '0'
164
+ requirements: []
165
+ rubyforge_project:
166
+ rubygems_version: 2.6.14.4
167
+ signing_key:
168
+ specification_version: 4
169
+ summary: Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
170
+ with Rumale interface.
171
+ test_files: []