rumale-svm 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.coveralls.yml +1 -0
- data/.gitignore +18 -0
- data/.rspec +3 -0
- data/.travis.yml +13 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +27 -0
- data/README.md +92 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/rumale/svm.rb +11 -0
- data/lib/rumale/svm/linear_svc.rb +238 -0
- data/lib/rumale/svm/linear_svr.rb +150 -0
- data/lib/rumale/svm/logistic_regression.rb +190 -0
- data/lib/rumale/svm/nu_svc.rb +193 -0
- data/lib/rumale/svm/nu_svr.rb +156 -0
- data/lib/rumale/svm/one_class_svm.rb +150 -0
- data/lib/rumale/svm/svc.rb +194 -0
- data/lib/rumale/svm/svr.rb +160 -0
- data/lib/rumale/svm/version.rb +10 -0
- data/rumale-svm.gemspec +40 -0
- metadata +171 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA1:
|
3
|
+
metadata.gz: 40bfeaba6ea5d4b31d75b3c7c7573b121aa2390a
|
4
|
+
data.tar.gz: a59abc913381101d82c488e145c9e375d2e55f41
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 333cce8eb4f15d6a23e7cc31b2cd5d122e3f59ce61971e0fa9d8c01144a4e67b44139a367132752d08c269b7564330503cb13558f93accc37e520f1542688c06
|
7
|
+
data.tar.gz: c5fc3f8c8287603369c1b801c6615644a13b75fdd75fb21f3470aa2f067d8d0800cac59b9b7fa9984405cf55a8edcf9e8e4ed46f2a65e0bf8e9791372fd5992c
|
data/.coveralls.yml
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
service_name: travis-ci
|
data/.gitignore
ADDED
data/.rspec
ADDED
data/.travis.yml
ADDED
data/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Contributor Covenant Code of Conduct
|
2
|
+
|
3
|
+
## Our Pledge
|
4
|
+
|
5
|
+
In the interest of fostering an open and welcoming environment, we as
|
6
|
+
contributors and maintainers pledge to making participation in our project and
|
7
|
+
our community a harassment-free experience for everyone, regardless of age, body
|
8
|
+
size, disability, ethnicity, gender identity and expression, level of experience,
|
9
|
+
nationality, personal appearance, race, religion, or sexual identity and
|
10
|
+
orientation.
|
11
|
+
|
12
|
+
## Our Standards
|
13
|
+
|
14
|
+
Examples of behavior that contributes to creating a positive environment
|
15
|
+
include:
|
16
|
+
|
17
|
+
* Using welcoming and inclusive language
|
18
|
+
* Being respectful of differing viewpoints and experiences
|
19
|
+
* Gracefully accepting constructive criticism
|
20
|
+
* Focusing on what is best for the community
|
21
|
+
* Showing empathy towards other community members
|
22
|
+
|
23
|
+
Examples of unacceptable behavior by participants include:
|
24
|
+
|
25
|
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26
|
+
advances
|
27
|
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28
|
+
* Public or private harassment
|
29
|
+
* Publishing others' private information, such as a physical or electronic
|
30
|
+
address, without explicit permission
|
31
|
+
* Other conduct which could reasonably be considered inappropriate in a
|
32
|
+
professional setting
|
33
|
+
|
34
|
+
## Our Responsibilities
|
35
|
+
|
36
|
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37
|
+
behavior and are expected to take appropriate and fair corrective action in
|
38
|
+
response to any instances of unacceptable behavior.
|
39
|
+
|
40
|
+
Project maintainers have the right and responsibility to remove, edit, or
|
41
|
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42
|
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43
|
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44
|
+
threatening, offensive, or harmful.
|
45
|
+
|
46
|
+
## Scope
|
47
|
+
|
48
|
+
This Code of Conduct applies both within project spaces and in public spaces
|
49
|
+
when an individual is representing the project or its community. Examples of
|
50
|
+
representing a project or community include using an official project e-mail
|
51
|
+
address, posting via an official social media account, or acting as an appointed
|
52
|
+
representative at an online or offline event. Representation of a project may be
|
53
|
+
further defined and clarified by project maintainers.
|
54
|
+
|
55
|
+
## Enforcement
|
56
|
+
|
57
|
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
58
|
+
reported by contacting the project team at yoshoku@outlook.com. All
|
59
|
+
complaints will be reviewed and investigated and will result in a response that
|
60
|
+
is deemed necessary and appropriate to the circumstances. The project team is
|
61
|
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
62
|
+
Further details of specific enforcement policies may be posted separately.
|
63
|
+
|
64
|
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
65
|
+
faith may face temporary or permanent repercussions as determined by other
|
66
|
+
members of the project's leadership.
|
67
|
+
|
68
|
+
## Attribution
|
69
|
+
|
70
|
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
71
|
+
available at [http://contributor-covenant.org/version/1/4][version]
|
72
|
+
|
73
|
+
[homepage]: http://contributor-covenant.org
|
74
|
+
[version]: http://contributor-covenant.org/version/1/4/
|
data/Gemfile
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
Copyright (c) 2019 Atsushi Tatsuma
|
2
|
+
All rights reserved.
|
3
|
+
|
4
|
+
Redistribution and use in source and binary forms, with or without
|
5
|
+
modification, are permitted provided that the following conditions are met:
|
6
|
+
|
7
|
+
* Redistributions of source code must retain the above copyright notice, this
|
8
|
+
list of conditions and the following disclaimer.
|
9
|
+
|
10
|
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11
|
+
this list of conditions and the following disclaimer in the documentation
|
12
|
+
and/or other materials provided with the distribution.
|
13
|
+
|
14
|
+
* Neither the name of the copyright holder nor the names of its
|
15
|
+
contributors may be used to endorse or promote products derived from
|
16
|
+
this software without specific prior written permission.
|
17
|
+
|
18
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
21
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
22
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
23
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
24
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
25
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
26
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
27
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
data/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
# Rumale::SVM
|
2
|
+
|
3
|
+
[![Build Status](https://travis-ci.org/yoshoku/rumale-svm.svg?branch=master)](https://travis-ci.org/yoshoku/rumale-svm)
|
4
|
+
[![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-svm/badge.svg?branch=master)](https://coveralls.io/github/yoshoku/rumale-svm?branch=master)
|
5
|
+
[![Gem Version](https://badge.fury.io/rb/rumale-svm.svg)](https://badge.fury.io/rb/rumale-svm)
|
6
|
+
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-svm/blob/master/LICENSE.txt)
|
7
|
+
[![Documentation](http://img.shields.io/badge/docs-rdoc.info-blue.svg)](https://yoshoku.github.io/rumale-svm/doc/)
|
8
|
+
|
9
|
+
Rumale::SVM provides support vector machine algorithms in
|
10
|
+
[LIBSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) and [LIBLINEAR](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
|
11
|
+
with [Rumale](https://github.com/yoshoku/rumale) interface.
|
12
|
+
Many machine learning libraries use LIBSVM and LIBLINEAR as backend libraries for support vector machine algorithms.
|
13
|
+
On the other hand, Rumale implements support vector machine algorithms based on the mini-batch stochastic gradient descent method
|
14
|
+
implemented in Ruby.
|
15
|
+
Rumale::SVM adds to Rumale the LIBSVM/LIBLINEAR-based support vector machine functionality found in general machine learning libraries.
|
16
|
+
|
17
|
+
## Installation
|
18
|
+
|
19
|
+
Add this line to your application's Gemfile:
|
20
|
+
|
21
|
+
```ruby
|
22
|
+
gem 'rumale-svm'
|
23
|
+
```
|
24
|
+
|
25
|
+
And then execute:
|
26
|
+
|
27
|
+
$ bundle
|
28
|
+
|
29
|
+
Or install it yourself as:
|
30
|
+
|
31
|
+
$ gem install rumale-svm
|
32
|
+
|
33
|
+
## Usage
|
34
|
+
Download the pendigits dataset from the [LIBSVM DATA](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) web page.
|
35
|
+
|
36
|
+
```sh
|
37
|
+
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
|
38
|
+
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t
|
39
|
+
```
|
40
|
+
|
41
|
+
Train a linear support vector classifier and save it (e.g. as `rumale_svm_train.rb`):
|
42
|
+
|
43
|
+
```ruby
|
44
|
+
require 'rumale/svm'
|
45
|
+
require 'rumale/dataset'
|
46
|
+
|
47
|
+
samples, labels = Rumale::Dataset.load_libsvm_file('pendigits')
|
48
|
+
svc = Rumale::SVM::LinearSVC.new(random_seed: 1)
|
49
|
+
svc.fit(samples, labels)
|
50
|
+
|
51
|
+
File.open('svc.dat', 'wb') { |f| f.write(Marshal.dump(svc)) }
|
52
|
+
```
|
53
|
+
|
54
|
+
Evaluate the classification accuracy on the testing dataset (e.g. as `rumale_svm_test.rb`):
|
55
|
+
|
56
|
+
```ruby
|
57
|
+
require 'rumale/svm'
|
58
|
+
require 'rumale/dataset'
|
59
|
+
|
60
|
+
samples, labels = Rumale::Dataset.load_libsvm_file('pendigits.t')
|
61
|
+
svc = Marshal.load(File.binread('svc.dat'))
|
62
|
+
|
63
|
+
puts "Accuracy: #{svc.score(samples, labels).round(3)}"
|
64
|
+
```
|
65
|
+
|
66
|
+
Execution result.
|
67
|
+
|
68
|
+
```sh
|
69
|
+
$ ruby rumale_svm_train.rb
|
70
|
+
$ ls svc.dat
|
71
|
+
svc.dat
|
72
|
+
$ ruby rumale_svm_test.rb
|
73
|
+
Accuracy: 0.835
|
74
|
+
```
|
75
|
+
|
76
|
+
## Development
|
77
|
+
|
78
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
79
|
+
|
80
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
81
|
+
|
82
|
+
## Contributing
|
83
|
+
|
84
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale-svm. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
|
85
|
+
|
86
|
+
## License
|
87
|
+
|
88
|
+
The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
|
89
|
+
|
90
|
+
## Code of Conduct
|
91
|
+
|
92
|
+
Everyone interacting in the Rumale::SVM project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/rumale-svm/blob/master/CODE_OF_CONDUCT.md).
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
#!/usr/bin/env ruby

# Interactive development console for rumale-svm.
# Sets up the bundle, loads the gem and drops into an IRB session so the
# estimators can be tried out by hand.

require 'bundler/setup'
require 'rumale/svm'

# Fixtures or other initialization code that makes experimenting easier
# can be placed here. Any other console may be used instead of IRB; e.g. to
# use Pry, add pry to the Gemfile and replace the IRB lines below with:
#   require 'pry'
#   Pry.start

require 'irb'
IRB.start(__FILE__)
|
data/bin/setup
ADDED
data/lib/rumale/svm.rb
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/svm/version'
|
4
|
+
require 'rumale/svm/svc'
|
5
|
+
require 'rumale/svm/svr'
|
6
|
+
require 'rumale/svm/nu_svc'
|
7
|
+
require 'rumale/svm/nu_svr'
|
8
|
+
require 'rumale/svm/one_class_svm'
|
9
|
+
require 'rumale/svm/linear_svc'
|
10
|
+
require 'rumale/svm/linear_svr'
|
11
|
+
require 'rumale/svm/logistic_regression'
|
@@ -0,0 +1,238 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/liblinear'
|
4
|
+
require 'rumale/base/base_estimator'
|
5
|
+
require 'rumale/base/classifier'
|
6
|
+
require 'rumale/probabilistic_output'
|
7
|
+
|
8
|
+
module Rumale
  module SVM
    # LinearSVC is a class that provides Support Vector Classifier in LIBLINEAR with Rumale interface.
    #
    # @example
    #   estimator = Rumale::SVM::LinearSVC.new(penalty: 'l2', loss: 'squared_hinge', reg_param: 1.0, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    class LinearSVC
      include Base::BaseEstimator
      include Base::Classifier

      # Return the weight vector for LinearSVC.
      # For binary classification a single weight vector (shape: [n_features]) is returned.
      # @return [Numo::DFloat] (shape: [n_classes, n_features])
      attr_reader :weight_vec

      # Return the bias term (a.k.a. intercept) for LinearSVC.
      # The bias column appended to the samples is scaled by bias_scale, so the learned
      # bias weight is multiplied by bias_scale here to give the actual intercept.
      # For binary classification a single Float is returned. It is zero when fit_bias is false.
      # @return [Numo::DFloat] (shape: [n_classes])
      attr_reader :bias_term

      # Create a new classifier with Support Vector Classifier.
      #
      # @param penalty [String] The type of norm used in the penalization ('l2' or 'l1').
      #   Any value other than 'l1' is treated as 'l2'.
      # @param loss [String] The type of loss function ('squared_hinge' or 'hinge').
      #   Any value other than 'hinge' is treated as 'squared_hinge'.
      #   This parameter is ignored if penalty = 'l1'.
      # @param dual [Boolean] The flag indicating whether to solve dual optimization problem.
      #   When n_samples > n_features, dual = false is preferable.
      #   This parameter is ignored if loss = 'hinge'.
      # @param reg_param [Float] The regularization parameter.
      # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
      # @param bias_scale [Float] The scale of the bias term.
      #   This parameter is ignored if fit_bias = false.
      # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
      # @param tol [Float] The tolerance of termination criterion.
      # @param verbose [Boolean] The flag indicating whether to output learning process message.
      # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
      def initialize(penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0,
                     fit_bias: true, bias_scale: 1.0, probability: false,
                     tol: 1e-3, verbose: false, random_seed: nil)
        check_params_string(penalty: penalty, loss: loss)
        check_params_float(reg_param: reg_param, bias_scale: bias_scale, tol: tol)
        check_params_boolean(dual: dual, fit_bias: fit_bias, probability: probability, verbose: verbose)
        check_params_type_or_nil(Integer, random_seed: random_seed)
        @params = {}
        # Unrecognized strings silently fall back to the defaults.
        @params[:penalty] = penalty == 'l1' ? 'l1' : 'l2'
        @params[:loss] = loss == 'hinge' ? 'hinge' : 'squared_hinge'
        @params[:dual] = dual
        @params[:reg_param] = reg_param
        @params[:fit_bias] = fit_bias
        @params[:bias_scale] = bias_scale
        @params[:probability] = probability
        @params[:tol] = tol
        @params[:verbose] = verbose
        @params[:random_seed] = random_seed
        @model = nil
        @weight_vec = nil
        @bias_term = nil
        @prob_param = nil
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [LinearSVC] The learned classifier itself.
      def fit(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        xx = fit_bias? ? expand_feature(x) : x
        @model = Numo::Liblinear.train(xx, y, liblinear_params)
        @weight_vec, @bias_term = weight_and_bias(@model[:w])
        # Pass the already expanded matrix; decision values are only computed when probability is true.
        @prob_param = proba_model(xx, y)
        self
      end

      # Calculate confidence scores for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
      def decision_function(x)
        check_sample_array(x)
        xx = fit_bias? ? expand_feature(x) : x
        Numo::Liblinear.decision_function(xx, liblinear_params, @model)
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        check_sample_array(x)
        xx = fit_bias? ? expand_feature(x) : x
        Numo::Int32.cast(Numo::Liblinear.predict(xx, liblinear_params, @model))
      end

      # Predict class probability for samples.
      # This method works correctly only if the probability parameter is true. Otherwise the
      # sigmoid parameters are placeholders (A = 1, B = 0) and the values are not calibrated.
      # Columns follow the label order stored in the trained LIBLINEAR model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        check_sample_array(x)
        df = decision_function(x)
        if binary_class?
          probs = Numo::DFloat.zeros(x.shape[0], 2)
          probs[true, 0] = 1.0 / (Numo::NMath.exp(@prob_param[0] * df + @prob_param[1]) + 1.0)
          probs[true, 1] = 1.0 - probs[true, 0]
          probs
        else
          # One sigmoid per class, then normalize each row so it sums to one.
          probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * df + @prob_param[true, 1]) + 1.0)
          (probs.transpose / probs.sum(axis: 1)).transpose.dup
        end
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about LinearSVC.
      def marshal_dump
        { params: @params,
          model: @model,
          weight_vec: @weight_vec,
          bias_term: @bias_term,
          prob_param: @prob_param }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @model = obj[:model]
        @weight_vec = obj[:weight_vec]
        @bias_term = obj[:bias_term]
        @prob_param = obj[:prob_param]
        nil
      end

      private

      # Append a constant column (value: bias_scale) so LIBLINEAR learns the bias as an ordinary weight.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * bias_scale])
      end

      # Split the flat LIBLINEAR weight array into [weight_vec, bias_term].
      # Multi-class weights are reshaped to (n_features, n_classes) and transposed to (n_classes, n_features).
      # The bias weight is the last feature column; it is multiplied by bias_scale because
      # the corresponding input column holds bias_scale rather than 1.
      def weight_and_bias(base_weight)
        if binary_class?
          return [base_weight.dup, 0.0] unless fit_bias?

          [base_weight[0...-1].dup, base_weight[-1] * bias_scale]
        else
          weight_mat = base_weight.reshape(n_features, n_classes).transpose.dup
          return [weight_mat, Numo::DFloat.zeros(n_classes)] unless fit_bias?

          [weight_mat[true, 0...-1].dup, weight_mat[true, -1] * bias_scale]
        end
      end

      # Fit the sigmoid parameters [A, B] used by predict_proba via
      # Rumale::ProbabilisticOutput.fit_sigmoid (one row per class for multi-class).
      # Returns the placeholder parameters [1, 0] when probability is false.
      # @param xx [Numo::DFloat] Training samples, already expanded with the bias column if fit_bias is true.
      # @param y [Numo::Int32] Training labels.
      def proba_model(xx, y)
        res = binary_class? ? Numo::DFloat[1, 0] : Numo::DFloat.cast([[1, 0]] * n_classes)
        return res unless fit_probability?

        df = Numo::Liblinear.decision_function(xx, liblinear_params, @model)
        if binary_class?
          # +1 for samples of the first model label, -1 otherwise.
          bin_y = Numo::Int32.cast(y.eq(labels[0])) * 2 - 1
          res = Rumale::ProbabilisticOutput.fit_sigmoid(df, bin_y)
        else
          # One-vs-rest sigmoid per class, in the model's label order.
          labels.each_with_index do |c, n|
            bin_y = Numo::Int32.cast(y.eq(c)) * 2 - 1
            res[n, true] = Rumale::ProbabilisticOutput.fit_sigmoid(df[true, n], bin_y)
          end
        end
        res
      end

      # Parameters passed to Numo::Liblinear for training and prediction.
      def liblinear_params
        { solver_type: solver_type,
          eps: @params[:tol],
          C: @params[:reg_param],
          verbose: @params[:verbose],
          random_seed: @params[:random_seed] }
      end

      # Map penalty / loss / dual to a LIBLINEAR solver type.
      def solver_type
        return Numo::Liblinear::SolverType::L1R_L2LOSS_SVC if @params[:penalty] == 'l1'
        return Numo::Liblinear::SolverType::L2R_L1LOSS_SVC_DUAL unless @params[:loss] == 'squared_hinge'

        if @params[:dual]
          Numo::Liblinear::SolverType::L2R_L2LOSS_SVC_DUAL
        else
          Numo::Liblinear::SolverType::L2R_L2LOSS_SVC
        end
      end

      def binary_class?
        @model[:nr_class] == 2
      end

      def fit_probability?
        @params[:probability]
      end

      def fit_bias?
        @params[:fit_bias]
      end

      def bias_scale
        @params[:bias_scale]
      end

      def n_classes
        @model[:nr_class]
      end

      # Number of features seen by LIBLINEAR (includes the appended bias column when fit_bias is true).
      def n_features
        @model[:nr_feature]
      end

      # Class labels in the order stored in the LIBLINEAR model.
      def labels
        @model[:label]
      end
    end
  end
end
|