jrb-libsvm 0.1.2-java
Sign up to get free protection for your applications and to get access to all the features.
- data/.gitignore +16 -0
- data/.rspec +3 -0
- data/.ruby-version +1 -0
- data/.travis.yml +4 -0
- data/.versions.conf +4 -0
- data/Gemfile +9 -0
- data/LIBSVM-LICENSE +30 -0
- data/MIT-LICENSE +22 -0
- data/README.md +105 -0
- data/Rakefile +16 -0
- data/java/3-11_w_squared/Svm.java +2826 -0
- data/java/COPYRIGHT +31 -0
- data/java/libsvm/Model.java +23 -0
- data/java/libsvm/Node.java +6 -0
- data/java/libsvm/Parameter.java +47 -0
- data/java/libsvm/PrintInterface.java +5 -0
- data/java/libsvm/Problem.java +7 -0
- data/java/libsvm/Svm.java +2814 -0
- data/jrb-libsvm.gemspec +23 -0
- data/lib/java/libsvm.jar +0 -0
- data/lib/jrb-libsvm/model.rb +97 -0
- data/lib/jrb-libsvm/node.rb +37 -0
- data/lib/jrb-libsvm/parameter.rb +66 -0
- data/lib/jrb-libsvm/problem.rb +35 -0
- data/lib/jrb-libsvm/version.rb +3 -0
- data/lib/jrb-libsvm.rb +31 -0
- data/spec/model_spec.rb +119 -0
- data/spec/node_spec.rb +62 -0
- data/spec/parameter_spec.rb +79 -0
- data/spec/problem_spec.rb +37 -0
- data/spec/spec_helper.rb +9 -0
- data/spec/usage_spec.rb +47 -0
- data/tmp/.gitkeep +0 -0
- metadata +108 -0
data/jrb-libsvm.gemspec
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
# -*- encoding: utf-8 -*-
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'jrb-libsvm/version'

Gem::Specification.new do |gem|
  gem.name          = "jrb-libsvm"
  gem.version       = JrbLibsvm::VERSION
  gem.platform      = 'java'
  gem.authors       = ["Andreas Eger"]
  gem.email         = ["dev@eger-andreas.de"]
  gem.description   = %q{JRuby language bindings for libsvm}
  gem.summary       = %q{basic wrapper around the java libsvm library}
  gem.homepage      = "https://github.com/sch1zo/jrb-libsvm"

  gem.files         = `git ls-files`.split($/)
  gem.executables   = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) }
  gem.test_files    = gem.files.grep(%r{^(test|spec|features)/})
  gem.require_paths = ["lib"]

  # The specs rely on RSpec 2 APIs (`should` syntax, `have(n).items`,
  # `be_true`) that RSpec 3 removed, so an open-ended `>= 2.7.0` would
  # resolve to a version the specs cannot run under. Stay on 2.x.
  gem.add_development_dependency('rspec', '~> 2.7')
end
|
data/lib/java/libsvm.jar
ADDED
Binary file
|
@@ -0,0 +1,97 @@
|
|
1
|
+
java_import 'java.io.DataOutputStream'
|
2
|
+
java_import 'java.io.ByteArrayOutputStream'
|
3
|
+
|
4
|
+
java_import 'java.io.StringReader'
|
5
|
+
java_import 'java.io.BufferedReader'
|
6
|
+
|
7
|
+
module Libsvm
  # Ruby convenience methods added to the libsvm Java model class
  # (java_import'ed as Libsvm::Model in lib/jrb-libsvm.rb).
  class Model
    class << self
      # Train a new model.
      #
      # @param problem [Libsvm::Problem] labelled training examples
      # @param parameter [Libsvm::Parameter] SVM and kernel settings
      # @return [Libsvm::Model] the model produced by Svm.svm_train
      def train(problem, parameter)
        Svm.svm_train(problem, parameter)
      end

      # Load a model previously written with #save.
      #
      # @param filename [String] path of the model file
      # @return [Libsvm::Model]
      # @raise [IOError] when the Java loader raises java.io.IOException
      def load(filename)
        Svm.svm_load_model(filename)
      rescue java.io.IOException
        raise IOError, "Error in loading SVM model from file"
      end

      # Load a model from the text produced by #serialize.
      #
      # @param string [String] serialized model text
      # @return [Libsvm::Model]
      # @raise [IOError] when the Java loader raises java.io.IOException
      def parse(string)
        reader = BufferedReader.new(StringReader.new(string))
        Svm.svm_load_model(reader)
      rescue java.io.IOException
        raise IOError, "Error in loading SVM model from string"
      end
    end

    # @return the svm_type stored in this model's parameters
    def svm_type
      param.svm_type
    end

    # @return the kernel_type stored in this model's parameters
    def kernel_type
      param.kernel_type
    end

    # @return the degree parameter stored in this model's parameters
    def degree
      param.degree
    end

    # @return the gamma parameter stored in this model's parameters
    def gamma
      param.gamma
    end

    # @return the cost (C) parameter stored in this model's parameters
    def cost
      param.c
    end

    # @return [Integer] the number of classes (nr_class) handled by this model
    def classes
      nr_class
    end

    # Save the model to +filename+.
    #
    # @param filename [String] destination path
    # @raise [IOError] when the Java writer raises java.io.IOException; the
    #   message includes the underlying Java error
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException => e
      # `=> e` binds the raised exception; the previous `rescue e = ...`
      # form bound the exception *class*, losing the actual error detail.
      raise IOError, "Error in saving SVM model to file: #{e.message}"
    end

    # Serialize the model into the same text format #save writes.
    #
    # @return [String] serialized model text
    # @raise [IOError] when the Java writer raises java.io.IOException
    def serialize
      buffer = ByteArrayOutputStream.new
      Svm.svm_save_model(DataOutputStream.new(buffer), self)
      buffer.to_s
    rescue java.io.IOException
      raise IOError, "Error in serializing SVM model"
    end

    # @return [String] the serialized model text (same as #serialize)
    def to_s
      serialize
    end

    # Predict the label of +example+.
    #
    # Without a block this calls Svm.svm_predict. With a block it calls
    # #predict_probability, yields the per-class probability array, and
    # returns the prediction.
    #
    # @param example [Array<Libsvm::Node>] the feature nodes to classify
    # @yieldparam probabilities the per-class probability estimates
    # @return the predicted label
    def predict(example, &block)
      return Svm.svm_predict(self, example) unless block

      prediction, probabilities = predict_probability(example)
      yield probabilities
      prediction
    end

    # Predict with probability estimates.
    #
    # @param example [Array<Libsvm::Node>] the feature nodes to classify
    # @return [Array] a pair of the predicted label and a Java double[] of
    #   length #classes, filled in by Svm.svm_predict_probability
    def predict_probability(example)
      probabilities = Java::double[classes].new
      prediction = Svm.svm_predict_probability(self, example, probabilities)
      [prediction, probabilities]
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Libsvm
  # Ruby convenience methods added to the libsvm Java node class
  # (java_import'ed as Libsvm::Node in lib/jrb-libsvm.rb).
  class Node
    class << self
      # Build an array of Nodes from feature values.
      #
      # Accepted inputs:
      # * an Array            -> node i gets index i and the i-th value
      # * a Hash              -> one node per index => value pair
      # * numeric varargs     -> node i gets index i and the i-th argument
      #   (a single numeric argument is treated the same way)
      #
      # Indices are coerced with #to_i and values with #to_f.
      #
      # @return [Array<Libsvm::Node>]
      # @raise [ArgumentError] when given exactly one argument that is not an
      #   Array, Hash or Numeric
      def features(*vargs)
        pairs =
          if vargs.size == 1 && !vargs.first.is_a?(Numeric)
            case (collection = vargs.first)
            when Array then collection.each_with_index.map { |value, index| [index, value] }
            when Hash then collection.to_a
            else raise ArgumentError, "Node features need to be a Hash, Array or Floats"
            end
          else
            # Positional values (including a lone scalar): index them 0..n-1.
            vargs.each_with_index.map { |value, index| [index, value] }
          end
        pairs.map { |index, value| Node.new(index.to_i, value.to_f) }
      end
    end

    # @param index [Integer] feature index
    # @param value [Float] feature value
    def initialize(index = 0, value = 0.0)
      super()
      self.index = index
      self.value = value
    end

    # Value equality: two nodes are equal when both index and value match.
    # Objects that do not expose index/value (e.g. nil) are never equal.
    def ==(other)
      other.respond_to?(:index) && other.respond_to?(:value) &&
        index == other.index && value == other.value
    end
  end
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module Libsvm
  # Ruby convenience methods added to the libsvm Java parameter class
  # (java_import'ed as Libsvm::Parameter in lib/jrb-libsvm.rb).
  class Parameter
    # Lower-case aliases for the cost accessor, so callers can use c / c=.
    alias :c :C
    alias :c= :C=

    # Per-class weights as a Hash of label => weight.
    #
    # @return [Hash] pairs built from weight_label and weight; empty when
    #   either is nil (#initialize never sets them, so presumably this is
    #   the case until #label_weights= is called)
    def label_weights
      return {} if self.weight_label.nil? || self.weight.nil?

      Hash[self.weight_label.zip(self.weight)]
    end

    # Set per-class weights from a Hash of label => weight. Also sets
    # nr_weight to the number of entries.
    #
    # @param v [Hash] label (Integer) => weight (Float)
    def label_weights=(v)
      self.nr_weight = v.keys.size
      self.weight_label = v.keys.to_java :int
      self.weight = v.values.to_java :double
    end

    # Constructor sets up values of attributes based on provided map.
    # Valid keys with their default values:
    # * :svm_type = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost = 1.0, for the cost or C parameter
    # * :gamma = 0.0, for the gamma parameter in kernel
    # * :degree = 1, for polynomial kernel
    # * :coef0 = 0.0, for polynomial/sigmoid kernels
    # * :eps = 0.001, for stopping criterion
    # * :nr_weight = 0, for C_SVC
    # * :nu = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p = 0.1, used for EPSILON_SVR
    # * :shrinking = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    # * :cache_size = 1, kernel cache size (libsvm measures it in MB)
    #
    # @raise [ArgumentError] if nu is outside (0, 1]
    def initialize(args = {})
      super()
      self.svm_type = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C = args.fetch(:cost, 1.0)
      self.gamma = args.fetch(:gamma, 0.0)
      self.degree = args.fetch(:degree, 1)
      self.coef0 = args.fetch(:coef0, 0.0)
      self.eps = args.fetch(:eps, 0.001)
      self.nr_weight = args.fetch(:nr_weight, 0)
      self.nu = args.fetch(:nu, 0.5)
      self.p = args.fetch(:p, 0.1)
      self.shrinking = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)
      self.cache_size = args.fetch(:cache_size, 1)

      unless self.nu > 0.0 && self.nu <= 1.0
        # Class and message must be passed separately; `ArgumentError "..."`
        # would be parsed as a call to an undefined ArgumentError() method.
        raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]"
      end
    end
  end

  # Alternate name for Parameter (e.g. Libsvm::SvmParameter).
  SvmParameter = Parameter

  # Kernel type constants taken from Parameter.
  module KernelType
    LINEAR = Parameter::LINEAR
    POLY = Parameter::POLY
    RBF = Parameter::RBF
    SIGMOID = Parameter::SIGMOID
    PRECOMPUTED = Parameter::PRECOMPUTED
  end

  # SVM type constants taken from Parameter.
  module SvmType
    C_SVC = Parameter::C_SVC
    NU_SVC = Parameter::NU_SVC
    ONE_CLASS = Parameter::ONE_CLASS
    EPSILON_SVR = Parameter::EPSILON_SVR
    NU_SVR = Parameter::NU_SVR
  end
end
|
@@ -0,0 +1,35 @@
|
|
1
|
+
module Libsvm
  # Ruby convenience methods added to the libsvm Java problem class
  # (java_import'ed as Libsvm::Problem in lib/jrb-libsvm.rb).
  class Problem
    # Replace this problem's training data.
    #
    # Sets l to the number of examples and fills a fresh 2-D Node array (x)
    # and a fresh double array (y) from +features+ and +labels+.
    #
    # @param labels [Array<Numeric>] one label per example
    # @param features [Array<Array<Libsvm::Node>>] one row of nodes per
    #   example; all rows must have the same length
    # @return [Integer] the number of examples stored
    # @raise [ArgumentError] if the counts differ, there are no examples, or
    #   the rows have differing lengths
    def set_examples(labels, features)
      if labels.size != features.size
        raise ArgumentError.new "Number of features must equal number of labels"
      end
      raise ArgumentError.new "There must be at least one feature." if features.size.zero?
      if features.map(&:size).uniq.size > 1
        raise ArgumentError.new "All features must have the same size"
      end

      self.l = labels.size

      # -- training data: assign the array first, then fill it row by row
      grid = self.x = Node[features.size, features.first.size].new
      features.each_with_index do |row, r|
        row.each_with_index { |node, c| grid[r][c] = node }
      end

      # -- labels
      targets = self.y = Java::double[labels.size].new
      labels.each_with_index { |label, idx| targets[idx] = label }

      labels.size
    end

    # @return [Array] the pair [labels (y), feature rows (x)]
    def examples
      [y, x]
    end
  end
end
|
data/lib/jrb-libsvm.rb
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
require_relative "jrb-libsvm/version"
|
2
|
+
require "java"
|
3
|
+
require_relative "java/libsvm"
|
4
|
+
|
5
|
+
module Libsvm
  # Expose the Java libsvm classes as Libsvm::Parameter, Libsvm::Model,
  # Libsvm::Problem, Libsvm::Node and Libsvm::Svm. The Ruby files required
  # after this module reopen those classes to add convenience methods.
  java_import "libsvm.Parameter",
              "libsvm.Model",
              "libsvm.Problem",
              "libsvm.Node",
              "libsvm.Svm"

  module CoreExtensions
    # Mixin for core collections (Hash, Array): converts the receiver into
    # feature nodes through Node.features.
    module Collection
      # @return [Array<Libsvm::Node>] the receiver converted via Node.features
      def to_example
        Node.features(self)
      end
    end
  end
end
|
20
|
+
|
21
|
+
require_relative 'jrb-libsvm/parameter'
|
22
|
+
require_relative 'jrb-libsvm/model'
|
23
|
+
require_relative 'jrb-libsvm/node'
|
24
|
+
require_relative 'jrb-libsvm/problem'
|
25
|
+
|
26
|
+
# Give Hash and Array a #to_example method (from
# Libsvm::CoreExtensions::Collection) that builds Libsvm::Node arrays.
[Hash, Array].each do |core_class|
  core_class.send(:include, Libsvm::CoreExtensions::Collection)
end
|
data/spec/model_spec.rb
ADDED
@@ -0,0 +1,119 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Model do
  def create_example
    Node.features(0.2, 0.3, 0.4, 0.5)
  end

  def create_problem
    problem = Problem.new
    features = [Node.features([0.2,0.3,0.4,0.4]),
                Node.features([0.1,0.5,0.1,0.9]),
                Node.features([0.2,0.2,0.6,0.5]),
                Node.features([0.3,0.1,0.5,0.9])]
    problem.set_examples([1,2,1,2], features)
    problem
  end

  def create_parameter
    parameter = Parameter.new
    parameter.cache_size = 50 # mb
    parameter.eps = 0.01
    parameter.c = 10
    parameter.probability = 1
    parameter
  end

  context "The class interface" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
    end

    it "results from training on a problem under a certain parameter set" do
      model = Model.train(@problem,@parameter)
      model.should_not be_nil
    end
  end

  context "A saved model" do
    before(:each) do
      @filename = "tmp/svm_model.model"
      model = Model.train(create_problem, create_parameter)
      model.save(@filename)
      @model_string = model.to_s
    end

    it "can be loaded from a file" do
      model = Model.load(@filename)
      model.should_not be_nil
    end

    it "can be loaded from a string" do
      model = Model.parse @model_string
      model.should_not be_nil
    end

    it "should do the same for load and load_from_string" do
      model = Model.load @filename
      model2 = Model.parse @model_string
      model.serialize.should == model2.serialize
    end

    # Remove the model file only when it exists; an inline `rescue nil`
    # would also hide unrelated failures (e.g. permission errors).
    after(:each) do
      File.delete(@filename) if File.exist?(@filename)
    end
  end

  context "An Libsvm model" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
      @model = Model.train(@problem, @parameter)
      @file_path = "tmp/svm_model.model"
    end

    # File.exists? was deprecated and removed in Ruby 3.2; File.exist? is
    # the supported spelling.
    after(:each) do
      File.delete(@file_path) if File.exist?(@file_path)
    end

    it "can be saved to a file" do
      @model.save(@file_path)
      File.exist?(@file_path).should be_true
    end

    it "can be serialized to a string" do
      @model.serialize.should_not be_empty
    end

    it "should generate the same text for serialize and save" do
      @model.save(@file_path)
      @model.serialize.should == IO.read(@file_path)
    end

    it "can be asked for it's svm_type" do
      @model.svm_type.should == Parameter::C_SVC
    end

    it "can be asked for it's number of classes (aka. labels)" do
      @model.classes.should == 2
    end

    it "can predict" do
      prediction = @model.predict(create_example)
      prediction.should_not be_nil
    end

    it "can predict probability" do
      prediction, probabilities = @model.predict_probability(create_example)
      prediction.should_not be_nil
      probabilities.should have(@model.classes).items
      probabilities.each { |e| e.should_not be_nil }
    end

    it "can predict with block" do
      prediction = @model.predict(create_example) do |probabilities|
        probabilities.should be_all { |p| p.kind_of? Float }
        probabilities.count.should == @model.classes
      end
      prediction.should be_a Float
    end
  end
end
|
data/spec/node_spec.rb
ADDED
@@ -0,0 +1,62 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Node do
  context "construction" do
    it "using the properties" do
      n = Node.new
      n.index = 11
      n.value = 0.11
      n.index.should == 11
      n.value.should be_within(0.0001).of(0.11)
    end

    it "using the constructor parameters" do
      n = Node.new(14, 0.14)
      n.index.should == 14
      n.value.should be_within(0.0001).of(0.14)
    end
  end

  context "inner workings" do
    let(:node) { Node.new }

    it "can be created" do
      node.should_not be_nil
    end

    it "does not segfault on setting properties" do
      node.index = 99
      node.index.should == 99
      node.value = 3.141
      node.value.should be_within(0.00001).of(3.141)
    end

    it "has inited properties" do
      node.index.should == 0
      node.value.should be_within(0.00001).of(0)
    end

    # `each` rather than `map`: these loops only run assertions, their
    # results are never used.
    it "class can create nodes from an array" do
      ary = Node.features([0.1, 0.2, 0.3, 0.4, 0.5])
      ary.each { |n| n.class.should == Node }
      ary.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from variable parameters" do
      ary = Node.features(0.1, 0.2, 0.3, 0.4, 0.5)
      ary.each { |n| n.should be_a(Node) }
      ary.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from hash" do
      ary = Node.features(3=>0.3, 5=>0.5, 6=>0.6, 10=>1.0)
      ary.each { |n| n.class.should == Node }
      ary.map(&:value).sort.should == [0.3, 0.5, 0.6, 1.0]
      ary.map(&:index).sort.should == [3, 5, 6, 10]
    end

    it "implements a value-like equality, not identity-notion" do
      Node.new(1, 0.1).should == Node.new(1, 0.1)
    end
  end
end
|
@@ -0,0 +1,79 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Parameter do
  # Defines an example named +description+ that writes +value+ through the
  # "<attribute>=" setter and expects the "<attribute>" getter to return it.
  def self.it_round_trips(description, attribute, value)
    it description do
      @p.public_send("#{attribute}=", value)
      @p.public_send(attribute).should == value
    end
  end

  before do
    @p = Libsvm::SvmParameter.new
  end

  it "can be created with a constructor" do
    build = -> { Libsvm::SvmParameter.new(svm_type: Libsvm::SvmType::C_SVC, cost: 23, gamma: 65) }
    build.should_not raise_error
  end

  it "int svm_type" do
    SvmType::C_SVC.should == 0
    @p.svm_type = SvmType::C_SVC
    @p.svm_type.should == SvmType::C_SVC
  end

  it "int kernel_type" do
    KernelType::RBF.should == 2
    @p.kernel_type = KernelType::RBF
    @p.kernel_type.should == KernelType::RBF
  end

  it_round_trips "int degree", :degree, 99
  it_round_trips "double gamma", :gamma, 0.33
  it_round_trips "double coef0", :coef0, 0.99
  it_round_trips "double cache_size", :cache_size, 0.77

  it "double eps" do
    [0.111, 0.112].each do |eps|
      @p.eps = eps
      @p.eps.should == eps
    end
  end

  it_round_trips "double C", :c, 3.141

  it "can set and read weights (weight, weight_label, nr_weight members from struct)" do
    @p.label_weights = {1=> 1.2, 3=>0.2, 5=>0.888}
    @p.label_weights.should == {1=> 1.2, 3=>0.2, 5=>0.888}
  end

  it_round_trips "double nu", :nu, 1.1
  it_round_trips "double p", :p, 0.123
  it_round_trips "int shrinking", :shrinking, 22
  it_round_trips "int probability", :probability, 35
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe Problem do
  let(:problem) { Problem.new }

  # Raw feature values for four dense examples of four features each.
  let(:rows) do
    [[0.2, 0.3, 0.4, 0.4],
     [0.1, 0.5, 0.1, 0.9],
     [0.2, 0.2, 0.6, 0.5],
     [0.3, 0.1, 0.5, 0.9]]
  end

  let(:features) { rows.map { |values| Node.features(*values) } }

  it "examples get stored and retrieved" do
    problem.set_examples([1,2,1,2], features)
    labels, examples = problem.examples
    labels.size.should == 4
    examples.size.should == 4
    examples.map(&:size).should == [4,4,4,4]
    examples.first.map(&:index).should == [0,1,2,3]
    examples.first.map(&:value).should == rows.first
  end

  it "can be populated" do
    ->{ problem.set_examples([1,2,1,2], features) }.should_not raise_error
  end

  it "can be set twice over" do
    pair = -> { [Node.features(0.2, 0.3, 0.4, 0.4), Node.features(0.3,0.1,0.5,0.9)] }
    problem.set_examples([1,2], pair.call)
    ->{ problem.set_examples([8,2], pair.call) }.should_not raise_error
  end
end
|
data/spec/spec_helper.rb
ADDED
data/spec/usage_spec.rb
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe "Basic usage" do
  let(:problem) { Problem.new }

  let(:parameter) do
    Parameter.new.tap do |param|
      param.cache_size = 1 # mb

      # "eps is the stopping criterion (we usually use 0.00001 in nu-SVC,
      # 0.001 in others)." (from README)
      param.eps = 0.001

      param.c = 10
    end
  end

  it "has a nice API" do
    sparse = {11 => 0.11, 21 => 0.21, 101 => 0.99 }
    sparse.to_example.should == Node.features(sparse)
  end

  it "is as in [PCI,217]" do
    training = [ [1,0,1], [-1,0,-1] ].map { |values| Node.features(values) }
    problem.set_examples([1, -1], training)

    model = Model.train(problem, parameter)

    {
      [1, 1, 1] => 1.0,
      [-1, 1, -1] => -1.0,
      [-1, 55, -1] => -1.0
    }.each do |values, expected|
      model.predict(Node.features(*values)).should == expected
    end
  end

  it "kernel parameter use" do
    parameter.kernel_type = Parameter::RBF
    training = [[1, 2, 3], [-2, -2, -2]].map { |values| Node.features(values) }
    problem.set_examples([1, 2], training)

    Model.train(problem, parameter).predict(Node.features(1, 2, 3)).should == 2
  end
end
|
data/tmp/.gitkeep
ADDED
File without changes
|