rumale-svm 0.1.0
- checksums.yaml +7 -0
- data/.coveralls.yml +1 -0
- data/.gitignore +18 -0
- data/.rspec +3 -0
- data/.travis.yml +13 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +27 -0
- data/README.md +92 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/rumale/svm.rb +11 -0
- data/lib/rumale/svm/linear_svc.rb +238 -0
- data/lib/rumale/svm/linear_svr.rb +150 -0
- data/lib/rumale/svm/logistic_regression.rb +190 -0
- data/lib/rumale/svm/nu_svc.rb +193 -0
- data/lib/rumale/svm/nu_svr.rb +156 -0
- data/lib/rumale/svm/one_class_svm.rb +150 -0
- data/lib/rumale/svm/svc.rb +194 -0
- data/lib/rumale/svm/svr.rb +160 -0
- data/lib/rumale/svm/version.rb +10 -0
- data/rumale-svm.gemspec +40 -0
- metadata +171 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
+---
+SHA1:
+  metadata.gz: 40bfeaba6ea5d4b31d75b3c7c7573b121aa2390a
+  data.tar.gz: a59abc913381101d82c488e145c9e375d2e55f41
+SHA512:
+  metadata.gz: 333cce8eb4f15d6a23e7cc31b2cd5d122e3f59ce61971e0fa9d8c01144a4e67b44139a367132752d08c269b7564330503cb13558f93accc37e520f1542688c06
+  data.tar.gz: c5fc3f8c8287603369c1b801c6615644a13b75fdd75fb21f3470aa2f067d8d0800cac59b9b7fa9984405cf55a8edcf9e8e4ed46f2a65e0bf8e9791372fd5992c
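The checksums.yaml hunk above records SHA1 and SHA512 digests for the two archive members of the gem, metadata.gz and data.tar.gz. As a minimal sketch of how such a digest could be checked with Ruby's standard library, assuming the member has already been extracted to a local file: the helper name `digest_matches?` is hypothetical, and the demonstration runs on a temporary file because the real archive is not available here.

```ruby
require 'digest'
require 'tempfile'

# Hypothetical helper: true when the SHA-512 hex digest of the file at +path+
# equals +expected_hex+ (the value listed under "SHA512:" in checksums.yaml).
def digest_matches?(path, expected_hex)
  Digest::SHA512.file(path).hexdigest == expected_hex.strip.downcase
end

# Demonstration on a temporary file standing in for data.tar.gz.
Tempfile.create('payload') do |f|
  f.binmode
  f.write('example payload')
  f.flush
  expected = Digest::SHA512.hexdigest('example payload')
  puts digest_matches?(f.path, expected)   # true
  puts digest_matches?(f.path, '0' * 128)  # false
end
```

For the SHA1 entries, the same pattern would apply with `Digest::SHA1` in place of `Digest::SHA512`.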
data/.coveralls.yml
ADDED
@@ -0,0 +1 @@
+service_name: travis-ci
data/.gitignore
ADDED
data/.rspec
ADDED
data/.travis.yml
ADDED
data/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,74 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, gender identity and expression, level of experience,
+nationality, personal appearance, race, religion, or sexual identity and
+orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at yoshoku@outlook.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at [http://contributor-covenant.org/version/1/4][version]
+
+[homepage]: http://contributor-covenant.org
+[version]: http://contributor-covenant.org/version/1/4/
data/Gemfile
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,27 @@
+Copyright (c) 2019 Atsushi Tatsuma
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md
ADDED
@@ -0,0 +1,92 @@
+# Rumale::SVM
+
+[Build Status](https://travis-ci.org/yoshoku/rumale-svm)
+[Coverage Status](https://coveralls.io/github/yoshoku/rumale-svm?branch=master)
+[Gem Version](https://badge.fury.io/rb/rumale-svm)
+[BSD 3-Clause License](https://github.com/yoshoku/rumale-svm/blob/master/LICENSE.txt)
+[Documentation](https://yoshoku.github.io/rumale-svm/doc/)
+
+Rumale::SVM provides the support vector machine algorithms of
+[LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) and [LIBLINEAR](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
+through the [Rumale](https://github.com/yoshoku/rumale) interface.
+Many machine learning libraries use LIBSVM and LIBLINEAR as backend libraries for support vector machine algorithms.
+Rumale, on the other hand, implements its support vector machine algorithms with mini-batch stochastic gradient descent
+written in Ruby.
+Rumale::SVM brings to Rumale the support vector machine functionality found in general machine learning libraries.
+
+## Installation
+
+Add this line to your application's Gemfile:
+
+```ruby
+gem 'rumale-svm'
+```
+
+And then execute:
+
+    $ bundle
+
+Or install it yourself as:
+
+    $ gem install rumale-svm
+
+## Usage
+Download the pendigits dataset from the [LIBSVM DATA](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) web page.
+
+```sh
+$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
+$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t
+```
+
+Train a linear support vector classifier.
+
+```ruby
+require 'rumale/svm'
+require 'rumale/dataset'
+
+samples, labels = Rumale::Dataset.load_libsvm_file('pendigits')
+svc = Rumale::SVM::LinearSVC.new(random_seed: 1)
+svc.fit(samples, labels)
+
+File.open('svc.dat', 'wb') { |f| f.write(Marshal.dump(svc)) }
+```
+
+Evaluate classification accuracy on the testing dataset.
+
+```ruby
+require 'rumale/svm'
+require 'rumale/dataset'
+
+samples, labels = Rumale::Dataset.load_libsvm_file('pendigits.t')
+svc = Marshal.load(File.binread('svc.dat'))
+
+puts "Accuracy: #{svc.score(samples, labels).round(3)}"
+```
+
+Execution result.
+
+```sh
+$ ruby rumale_svm_train.rb
+$ ls svc.dat
+svc.dat
+$ ruby rumale_svm_test.rb
+Accuracy: 0.835
+```
+
+## Development
+
+After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
+
+To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
+
+## Contributing
+
+Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale-svm. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
+
+## License
+
+The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
+
+## Code of Conduct
+
+Everyone interacting in the Rumale::SVM project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/rumale-svm/blob/master/CODE_OF_CONDUCT.md).
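The README persists the trained classifier with `Marshal.dump` and restores it with `Marshal.load`; the estimator supports this by defining `marshal_dump` and `marshal_load`, so that only a hash of its state is serialized. The following is a minimal sketch of that pattern in plain Ruby. `TinyModel` and its "training" step are hypothetical stand-ins and are not part of rumale-svm.

```ruby
# Hypothetical stand-in for an estimator; not part of rumale-svm.
class TinyModel
  attr_reader :params, :weights

  def initialize(reg_param: 1.0)
    @params = { reg_param: reg_param }
    @weights = nil
  end

  # Placeholder "training": store the mean of the inputs.
  def fit(values)
    @weights = values.sum.to_f / values.size
    self
  end

  # Marshal calls this to obtain the object that is actually serialized.
  def marshal_dump
    { params: @params, weights: @weights }
  end

  # Marshal calls this on a freshly allocated instance when loading;
  # initialize is not run, so all state must be restored here.
  def marshal_load(obj)
    @params = obj[:params]
    @weights = obj[:weights]
    nil
  end
end

model = TinyModel.new(reg_param: 0.5).fit([1, 2, 3])
restored = Marshal.load(Marshal.dump(model))
puts restored.weights # 2.0
```

Note that `Marshal.load` can instantiate arbitrary objects, so a file such as `svc.dat` should only be loaded when it comes from a trusted source.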
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env ruby
+
+require "bundler/setup"
+require "rumale/svm"
+
+# You can add fixtures and/or initialization code here to make experimenting
+# with your gem easier. You can also use a different console, if you like.
+
+# (If you use this, don't forget to add pry to your Gemfile!)
+# require "pry"
+# Pry.start
+
+require "irb"
+IRB.start(__FILE__)
data/bin/setup
ADDED
data/lib/rumale/svm.rb
ADDED
@@ -0,0 +1,11 @@
+# frozen_string_literal: true
+
+require 'rumale/svm/version'
+require 'rumale/svm/svc'
+require 'rumale/svm/svr'
+require 'rumale/svm/nu_svc'
+require 'rumale/svm/nu_svr'
+require 'rumale/svm/one_class_svm'
+require 'rumale/svm/linear_svc'
+require 'rumale/svm/linear_svr'
+require 'rumale/svm/logistic_regression'
data/lib/rumale/svm/linear_svc.rb
ADDED
@@ -0,0 +1,238 @@
+# frozen_string_literal: true
+
+require 'numo/liblinear'
+require 'rumale/base/base_estimator'
+require 'rumale/base/classifier'
+require 'rumale/probabilistic_output'
+
+module Rumale
+  module SVM
+    # LinearSVC is a class that provides Support Vector Classifier in LIBLINEAR with Rumale interface.
+    #
+    # @example
+    #   estimator = Rumale::SVM::LinearSVC.new(penalty: 'l2', loss: 'squared_hinge', reg_param: 1.0, random_seed: 1)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    class LinearSVC
+      include Base::BaseEstimator
+      include Base::Classifier
+
+      # Return the weight vector for LinearSVC.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :weight_vec
+
+      # Return the bias term (a.k.a. intercept) for LinearSVC.
+      # @return [Numo::DFloat] (shape: [n_classes])
+      attr_reader :bias_term
+
+      # Create a new classifier with Support Vector Classifier.
+      #
+      # @param penalty [String] The type of norm used in the penalization ('l2' or 'l1').
+      # @param loss [String] The type of loss function ('squared_hinge' or 'hinge').
+      #   This parameter is ignored if penalty = 'l1'.
+      # @param dual [Boolean] The flag indicating whether to solve dual optimization problem.
+      #   When n_samples > n_features, dual = false is preferable.
+      #   This parameter is ignored if loss = 'hinge'.
+      # @param reg_param [Float] The regularization parameter.
+      # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
+      # @param bias_scale [Float] The scale of the bias term.
+      #   This parameter is ignored if fit_bias = false.
+      # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
+      # @param tol [Float] The tolerance of termination criterion.
+      # @param verbose [Boolean] The flag indicating whether to output learning process messages.
+      # @param random_seed [Integer/Nil] The seed value used to initialize the random generator.
+      def initialize(penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0,
+                     fit_bias: true, bias_scale: 1.0, probability: false,
+                     tol: 1e-3, verbose: false, random_seed: nil)
+        check_params_string(penalty: penalty, loss: loss)
+        check_params_float(reg_param: reg_param, bias_scale: bias_scale, tol: tol)
+        check_params_boolean(dual: dual, fit_bias: fit_bias, probability: probability, verbose: verbose)
+        check_params_type_or_nil(Integer, random_seed: random_seed)
+        @params = {}
+        @params[:penalty] = penalty == 'l1' ? 'l1' : 'l2'
+        @params[:loss] = loss == 'hinge' ? 'hinge' : 'squared_hinge'
+        @params[:dual] = dual
+        @params[:reg_param] = reg_param
+        @params[:fit_bias] = fit_bias
+        @params[:bias_scale] = bias_scale
+        @params[:probability] = probability
+        @params[:tol] = tol
+        @params[:verbose] = verbose
+        @params[:random_seed] = random_seed
+        @model = nil
+        @weight_vec = nil
+        @bias_term = nil
+        @prob_param = nil
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [LinearSVC] The learned classifier itself.
+      def fit(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        xx = fit_bias? ? expand_feature(x) : x
+        @model = Numo::Liblinear.train(xx, y, liblinear_params)
+        @weight_vec, @bias_term = weight_and_bias(@model[:w])
+        @prob_param = proba_model(decision_function(x), y)
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
+      def decision_function(x)
+        check_sample_array(x)
+        xx = fit_bias? ? expand_feature(x) : x
+        Numo::Liblinear.decision_function(xx, liblinear_params, @model)
+      end
+
+      # Predict class labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+      def predict(x)
+        check_sample_array(x)
+        xx = fit_bias? ? expand_feature(x) : x
+        Numo::Int32.cast(Numo::Liblinear.predict(xx, liblinear_params, @model))
+      end
+
+      # Predict class probability for samples.
+      # This method works correctly only if the probability parameter is true.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+      def predict_proba(x)
+        check_sample_array(x)
+        if binary_class?
+          probs = Numo::DFloat.zeros(x.shape[0], 2)
+          probs[true, 0] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
+          probs[true, 1] = 1.0 - probs[true, 0]
+        else
+          probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
+          probs = (probs.transpose / probs.sum(axis: 1)).transpose.dup
+        end
+        probs
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data about LinearSVC.
+      def marshal_dump
+        { params: @params,
+          model: @model,
+          weight_vec: @weight_vec,
+          bias_term: @bias_term,
+          prob_param: @prob_param }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @model = obj[:model]
+        @weight_vec = obj[:weight_vec]
+        @bias_term = obj[:bias_term]
+        @prob_param = obj[:prob_param]
+        nil
+      end
+
+      private
+
+      def expand_feature(x)
+        n_samples = x.shape[0]
+        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * bias_scale])
+      end
+
+      def weight_and_bias(base_weight)
+        if binary_class?
+          bias_vec = 0.0
+          weight_mat = base_weight.dup
+          if fit_bias?
+            bias_vec = weight_mat[-1]
+            weight_mat = weight_mat[0...-1].dup
+          end
+        else
+          bias_vec = Numo::DFloat.zeros(n_classes)
+          weight_mat = base_weight.reshape(n_features, n_classes).transpose.dup
+          if fit_bias?
+            bias_vec = weight_mat[true, -1].dup
+            weight_mat = weight_mat[true, 0...-1].dup
+          end
+        end
+        [weight_mat, bias_vec]
+      end
+
+      def proba_model(df, y)
+        res = binary_class? ? Numo::DFloat[1, 0] : Numo::DFloat.cast([[1, 0]] * n_classes)
+        return res unless fit_probability?
+
+        if binary_class?
+          bin_y = Numo::Int32.cast(y.eq(labels[0])) * 2 - 1
+          res = Rumale::ProbabilisticOutput.fit_sigmoid(df, bin_y)
+        else
+          labels.each_with_index do |c, n|
+            bin_y = Numo::Int32.cast(y.eq(c)) * 2 - 1
+            res[n, true] = Rumale::ProbabilisticOutput.fit_sigmoid(df[true, n], bin_y)
+          end
+        end
+        res
+      end
+
+      def liblinear_params
+        res = {}
+        res[:solver_type] = solver_type
+        res[:eps] = @params[:tol]
+        res[:C] = @params[:reg_param]
+        res[:verbose] = @params[:verbose]
+        res[:random_seed] = @params[:random_seed]
+        res
+      end
+
+      def solver_type
+        return Numo::Liblinear::SolverType::L1R_L2LOSS_SVC if @params[:penalty] == 'l1'
+
+        if @params[:loss] == 'squared_hinge'
+          if @params[:dual]
+            Numo::Liblinear::SolverType::L2R_L2LOSS_SVC_DUAL
+          else
+            Numo::Liblinear::SolverType::L2R_L2LOSS_SVC
+          end
+        else
+          Numo::Liblinear::SolverType::L2R_L1LOSS_SVC_DUAL
+        end
+      end
+
+      def binary_class?
+        @model[:nr_class] == 2
+      end
+
+      def fit_probability?
+        @params[:probability]
+      end
+
+      def fit_bias?
+        @params[:fit_bias]
+      end
+
+      def bias_scale
+        @params[:bias_scale]
+      end
+
+      def n_classes
+        @model[:nr_class]
+      end
+
+      def n_features
+        @model[:nr_feature]
+      end
+
+      def labels
+        @model[:label]
+      end
+    end
+  end
+end
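The `predict_proba` method in linear_svc.rb maps decision values to probabilities with a sigmoid of the form 1 / (exp(a * score + b) + 1), using one (a, b) pair per class and, in the multiclass case, normalizing each row so it sums to 1. The sketch below restates that arithmetic on plain Ruby arrays instead of Numo arrays. The function names and the (a, b) values are illustrative assumptions; in the gem the pairs come from `Rumale::ProbabilisticOutput.fit_sigmoid` when `probability: true`.

```ruby
# Sigmoid in the form used by predict_proba: 1 / (exp(a * score + b) + 1).
def sigmoid_prob(score, a, b)
  1.0 / (Math.exp(a * score + b) + 1.0)
end

# Binary case: one decision value per sample and a single (a, b) pair.
# Returns [p, 1 - p] per sample, mirroring the two probability columns.
def binary_probs(scores, a, b)
  scores.map do |s|
    p0 = sigmoid_prob(s, a, b)
    [p0, 1.0 - p0]
  end
end

# Multiclass case: one decision value and one (a, b) pair per class,
# followed by row-wise normalization.
def multiclass_probs(scores, ab_pairs)
  scores.map do |row|
    raw = row.each_with_index.map { |s, k| sigmoid_prob(s, *ab_pairs[k]) }
    total = raw.sum
    raw.map { |p| p / total }
  end
end

scores = [[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]]
ab_pairs = [[-1.0, 0.0]] * 3 # illustrative values, not fitted parameters
multiclass_probs(scores, ab_pairs).each do |row|
  puts row.map { |p| p.round(3) }.inspect
end
```

Row normalization is a simple way to turn independent per-class sigmoid outputs into a distribution; it does not make the result a jointly calibrated multiclass probability.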