jrb-libsvm 0.1.2-java
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/.gitignore +16 -0
- data/.rspec +3 -0
- data/.ruby-version +1 -0
- data/.travis.yml +4 -0
- data/.versions.conf +4 -0
- data/Gemfile +9 -0
- data/LIBSVM-LICENSE +30 -0
- data/MIT-LICENSE +22 -0
- data/README.md +105 -0
- data/Rakefile +16 -0
- data/java/3-11_w_squared/Svm.java +2826 -0
- data/java/COPYRIGHT +31 -0
- data/java/libsvm/Model.java +23 -0
- data/java/libsvm/Node.java +6 -0
- data/java/libsvm/Parameter.java +47 -0
- data/java/libsvm/PrintInterface.java +5 -0
- data/java/libsvm/Problem.java +7 -0
- data/java/libsvm/Svm.java +2814 -0
- data/jrb-libsvm.gemspec +23 -0
- data/lib/java/libsvm.jar +0 -0
- data/lib/jrb-libsvm/model.rb +97 -0
- data/lib/jrb-libsvm/node.rb +37 -0
- data/lib/jrb-libsvm/parameter.rb +66 -0
- data/lib/jrb-libsvm/problem.rb +35 -0
- data/lib/jrb-libsvm/version.rb +3 -0
- data/lib/jrb-libsvm.rb +31 -0
- data/spec/model_spec.rb +119 -0
- data/spec/node_spec.rb +62 -0
- data/spec/parameter_spec.rb +79 -0
- data/spec/problem_spec.rb +37 -0
- data/spec/spec_helper.rb +9 -0
- data/spec/usage_spec.rb +47 -0
- data/tmp/.gitkeep +0 -0
- metadata +108 -0
data/jrb-libsvm.gemspec
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
# -*- encoding: utf-8 -*-
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'jrb-libsvm/version'

# Gem specification for jrb-libsvm: JRuby bindings around the bundled
# Java libsvm jar (lib/java/libsvm.jar). Java platform only.
Gem::Specification.new do |gem|
  gem.name          = "jrb-libsvm"
  gem.version       = JrbLibsvm::VERSION
  gem.platform      = 'java'
  gem.authors       = ["Andreas Eger"]
  gem.email         = ["dev@eger-andreas.de"]
  gem.description   = %q{JRuby language bindings for libsvm}
  gem.summary       = %q{basic wrapper around the java libsvm library}
  gem.homepage      = "https://github.com/sch1zo/jrb-libsvm"

  # Package everything git tracks; executables and test files are derived
  # from that list by path prefix.
  gem.files         = `git ls-files`.split($/)
  gem.executables   = gem.files.grep(%r{^bin/}).map { |f| File.basename(f) }
  gem.test_files    = gem.files.grep(%r{^(test|spec|features)/})
  gem.require_paths = ["lib"]

  # The specs are written against RSpec 2 (`should` syntax and RSpec 2
  # matchers). An open-ended ">= 2.7.0" would also admit RSpec 3, so keep
  # the development dependency on the 2.x series.
  gem.add_development_dependency('rspec', '~> 2.7')
end
|
data/lib/java/libsvm.jar
ADDED
Binary file
|
@@ -0,0 +1,97 @@
|
|
1
|
+
java_import 'java.io.DataOutputStream'
|
2
|
+
java_import 'java.io.ByteArrayOutputStream'
|
3
|
+
|
4
|
+
java_import 'java.io.StringReader'
|
5
|
+
java_import 'java.io.BufferedReader'
|
6
|
+
|
7
|
+
module Libsvm
  # Ruby-side extensions to the Java class libsvm.Model (a trained SVM),
  # which lib/jrb-libsvm.rb imports into this module.
  class Model
    class << self
      # Train a model on +problem+ using the settings in +parameter+.
      #
      # @param problem [Libsvm::Problem] labelled training examples
      # @param parameter [Libsvm::Parameter] SVM settings
      # @return [Libsvm::Model] the result of Svm.svm_train
      def train(problem, parameter)
        Svm.svm_train(problem, parameter)
      end
    end

    # Return the SVM problem type for this model
    def svm_type
      self.param.svm_type
    end

    # Return the kernel type for this model
    def kernel_type
      self.param.kernel_type
    end

    # Return the value of the degree parameter
    def degree
      self.param.degree
    end

    # Return the value of the gamma parameter
    def gamma
      self.param.gamma
    end

    # Return the value of the cost parameter
    def cost
      self.param.c
    end

    # Return the number of classes handled by this model.
    def classes
      self.nr_class
    end

    # Save model to given filename.
    #
    # @param filename [String] destination path
    # @raise [IOError] when the Java side raises java.io.IOException; the
    #   message includes the Java exception's text
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException => e
      # `=> e` binds the caught exception itself, so its text reaches the message.
      raise IOError, "Error in saving SVM model to file: #{e}"
    end

    # Serialize model and return a string.
    #
    # The model is written through a DataOutputStream into an in-memory
    # ByteArrayOutputStream, whose contents are returned.
    #
    # @return [String]
    # @raise [IOError] when the Java side raises java.io.IOException
    def serialize
      buffer = ByteArrayOutputStream.new
      data_stream = DataOutputStream.new(buffer)
      Svm.svm_save_model(data_stream, self)
      buffer.to_s
    rescue java.io.IOException
      raise IOError, "Error in serializing SVM model"
    end

    # Same as #serialize.
    def to_s
      serialize
    end

    # Load model from given filename.
    #
    # @param filename [String]
    # @return [Libsvm::Model]
    # @raise [IOError] when the Java side raises java.io.IOException
    def self.load(filename)
      Svm.svm_load_model(filename)
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from file"
    end

    # Load model from string (e.g. the output of #serialize).
    #
    # @param string [String]
    # @return [Libsvm::Model]
    # @raise [IOError] when the Java side raises java.io.IOException
    def self.parse(string)
      reader = BufferedReader.new(StringReader.new(string))
      Svm.svm_load_model(reader)
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from string"
    end

    # Predict the label of +example+.
    #
    # Without a block this delegates to Svm.svm_predict. With a block, the
    # probability estimates from #predict_probability are yielded to the
    # block and the predicted label is returned.
    #
    # @param example [Array<Libsvm::Node>]
    # @yieldparam probabilities [Java::double[]] one entry per class
    # @return the predicted label as returned by libsvm
    def predict(example)
      return Svm.svm_predict(self, example) unless block_given?

      prediction, probabilities = predict_probability(example)
      yield probabilities
      prediction
    end

    # Predict the label of +example+ together with per-class probabilities.
    #
    # @param example [Array<Libsvm::Node>]
    # @return [Array] two elements: the predicted label and a Java double[]
    #   of length #classes filled in by Svm.svm_predict_probability
    def predict_probability(example)
      probabilities = Java::double[self.classes].new
      prediction = Svm.svm_predict_probability(self, example, probabilities)
      [prediction, probabilities]
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Libsvm
  # Ruby-side extensions to the Java class libsvm.Node: one (index, value)
  # feature of an example.
  class Node
    class << self
      # Build an Array of Nodes from feature values.
      #
      # Accepted inputs:
      # * a single Array of values        - indices are the array positions
      # * a single Hash of index => value - indices are the hash keys
      # * one or more numbers as arguments - indices are the argument positions
      #
      # Indices are converted with #to_i and values with #to_f.
      #
      # @return [Array<Libsvm::Node>] empty when called with no arguments
      # @raise [ArgumentError] if a single argument is not an Array, Hash or Numeric
      def features(*vargs)
        # A lone argument is the collection itself; several arguments are
        # treated as a positional list of values.
        source = vargs.size == 1 ? vargs.first : vargs

        case source
        when Array
          source.each_with_index.map { |value, index| Node.new(index.to_i, value.to_f) }
        when Hash
          source.map { |index, value| Node.new(index.to_i, value.to_f) }
        when Numeric
          # A single plain number is a one-feature example at index 0.
          [Node.new(0, source.to_f)]
        else
          raise ArgumentError, "Node features need to be a Hash, Array or Floats"
        end
      end
    end

    # @param index [Integer] feature index
    # @param value [Float] feature value
    def initialize(index = 0, value = 0.0)
      super()
      self.index = index
      self.value = value
    end

    # Value equality: two Nodes are equal when index and value match.
    # Comparing with anything that is not a Node returns false.
    def ==(other)
      other.is_a?(Node) && index == other.index && value == other.value
    end
  end
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module Libsvm
  # Ruby-side extensions to the Java class libsvm.Parameter (SVM training
  # settings).
  class Parameter
    # Lower-case aliases for the Java `C` (cost) accessor pair.
    alias :c :C
    alias :c= :C=

    # Per-label weights as a Hash of label => weight.
    #
    # @return [Hash] empty when weight_label or weight has not been set (nil)
    def label_weights
      labels = self.weight_label
      weights = self.weight
      return {} if labels.nil? || weights.nil?

      Hash[labels.zip(weights)]
    end

    # Replace the per-label weights.
    #
    # @param v [Hash] label => weight. Keys become the Java int[] weight_label,
    #   values the Java double[] weight, and nr_weight is set to the key count.
    def label_weights=(v)
      self.nr_weight = v.keys.size
      self.weight_label = v.keys.to_java :int
      self.weight = v.values.to_java :double
    end

    # Constructor sets up values of attributes based on provided map.
    # Valid keys with their default values:
    # * :svm_type = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost = 1.0, for the cost or C parameter
    # * :gamma = 0.0, for the gamma parameter in kernel
    # * :degree = 1, for polynomial kernel
    # * :coef0 = 0.0, for polynomial/sigmoid kernels
    # * :eps = 0.001, for stopping criterion
    # * :nr_weight = 0, for C_SVC
    # * :nu = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p = 0.1, used for EPSILON_SVR
    # * :shrinking = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    # * :cache_size = 1, kernel cache size (the specs annotate it as MB)
    #
    # @raise [ArgumentError] if nu is not in (0,1]
    def initialize(args = {})
      super()
      self.svm_type    = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C           = args.fetch(:cost, 1.0)
      self.gamma       = args.fetch(:gamma, 0.0)
      self.degree      = args.fetch(:degree, 1)
      self.coef0       = args.fetch(:coef0, 0.0)
      self.eps         = args.fetch(:eps, 0.001)
      self.nr_weight   = args.fetch(:nr_weight, 0)
      self.nu          = args.fetch(:nu, 0.5)
      self.p           = args.fetch(:p, 0.1)
      self.shrinking   = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)
      self.cache_size  = args.fetch(:cache_size, 1)

      unless self.nu > 0.0 && self.nu <= 1.0
        raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]"
      end
    end
  end

  # Backwards-compatible alias.
  SvmParameter = Parameter

  # Kernel type constants re-exported from Parameter.
  module KernelType
    LINEAR = Parameter::LINEAR
    POLY = Parameter::POLY
    RBF = Parameter::RBF
    SIGMOID = Parameter::SIGMOID
    PRECOMPUTED = Parameter::PRECOMPUTED
  end

  # SVM type constants re-exported from Parameter.
  module SvmType
    C_SVC = Parameter::C_SVC
    NU_SVC = Parameter::NU_SVC
    ONE_CLASS = Parameter::ONE_CLASS
    EPSILON_SVR = Parameter::EPSILON_SVR
    NU_SVR = Parameter::NU_SVR
  end
end
|
@@ -0,0 +1,35 @@
|
|
1
|
+
module Libsvm
  # Ruby-side extensions to the Java class libsvm.Problem (a training set).
  class Problem
    # Load labelled training examples into this problem.
    #
    # @param labels [Array<Numeric>] one label per example
    # @param features [Array<Array<Libsvm::Node>>] one node list per example;
    #   every list must have the same length
    # @return [Integer] the number of examples stored
    # @raise [ArgumentError] when the counts differ, no examples are given,
    #   or the node lists differ in length
    def set_examples(labels, features)
      if labels.size != features.size
        raise ArgumentError, "Number of features must equal number of labels"
      end
      raise ArgumentError, "There must be at least one feature." if features.size.zero?

      shortest, longest = features.map(&:size).minmax
      raise ArgumentError, "All features must have the same size" if shortest != longest

      self.l = labels.size

      # Training vectors: a rectangular Java Node[][] with one row per example.
      self.x = Node[features.size, features[0].size].new
      features.each_with_index do |nodes, row|
        nodes.each_with_index do |node, col|
          self.x[row][col] = node
        end
      end

      # Labels: a Java double[] aligned with the rows of x.
      self.y = Java::double[labels.size].new
      labels.each_with_index { |label, row| self.y[row] = label }

      labels.size
    end

    # @return [Array] two elements: the label array (y) and the example array (x)
    def examples
      [self.y, self.x]
    end
  end
end
|
data/lib/jrb-libsvm.rb
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
require_relative "jrb-libsvm/version"
|
2
|
+
require "java"
|
3
|
+
require_relative "java/libsvm"
|
4
|
+
|
5
|
+
module Libsvm
  # Expose the Java classes from the bundled jar as Libsvm::<Name>.
  %w[Parameter Model Problem Node Svm].each do |java_class|
    java_import "libsvm.#{java_class}"
  end

  # Mixins that add libsvm conveniences to core Ruby classes.
  module CoreExtensions
    # Adds #to_example to collections of feature values.
    module Collection
      # Convert this collection into an Array of nodes via Node.features.
      def to_example
        Libsvm::Node.features(self)
      end
    end
  end
end
|
20
|
+
|
21
|
+
require_relative 'jrb-libsvm/parameter'
|
22
|
+
require_relative 'jrb-libsvm/model'
|
23
|
+
require_relative 'jrb-libsvm/node'
|
24
|
+
require_relative 'jrb-libsvm/problem'
|
25
|
+
|
26
|
+
# Mix Libsvm::CoreExtensions::Collection#to_example into the core collection
# classes, so e.g. `[0.1, 0.2].to_example` and `{3 => 0.3}.to_example` work.
[Hash, Array].each do |core_class|
  core_class.class_eval { include Libsvm::CoreExtensions::Collection }
end
|
data/spec/model_spec.rb
ADDED
@@ -0,0 +1,119 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Model do
  # An unlabelled four-feature example to predict on.
  def create_example
    Node.features(0.2, 0.3, 0.4, 0.5)
  end

  # A small training problem: four examples labelled 1, 2, 1, 2.
  def create_problem
    problem = Problem.new
    features = [Node.features([0.2, 0.3, 0.4, 0.4]),
                Node.features([0.1, 0.5, 0.1, 0.9]),
                Node.features([0.2, 0.2, 0.6, 0.5]),
                Node.features([0.3, 0.1, 0.5, 0.9])]
    problem.set_examples([1, 2, 1, 2], features)
    problem
  end

  # Training parameters with probability estimates switched on.
  def create_parameter
    parameter = Parameter.new
    parameter.cache_size = 50 # mb
    parameter.eps = 0.01
    parameter.c = 10
    parameter.probability = 1
    parameter
  end

  context "The class interface" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
    end

    it "results from training on a problem under a certain parameter set" do
      model = Model.train(@problem, @parameter)
      model.should_not be_nil
    end
  end

  context "A saved model" do
    before(:each) do
      @filename = "tmp/svm_model.model"
      model = Model.train(create_problem, create_parameter)
      model.save(@filename)
      @model_string = model.to_s
    end

    it "can be loaded from a file" do
      model = Model.load(@filename)
      model.should_not be_nil
    end

    it "can be loaded from a string" do
      model = Model.parse @model_string
      model.should_not be_nil
    end

    it "should do the same for load and load_from_string" do
      model = Model.load @filename
      model2 = Model.parse @model_string
      model.serialize.should == model2.serialize
    end

    after(:each) do
      # Remove the fixture file if it was written; don't hide other failures.
      File.delete(@filename) if File.exist?(@filename)
    end
  end

  context "An Libsvm model" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
      @model = Model.train(@problem, @parameter)
      @file_path = "tmp/svm_model.model"
    end

    after(:each) do
      # File.exist? (not the deprecated File.exists?, removed in Ruby 3.2).
      File.delete(@file_path) if File.exist?(@file_path)
    end

    it "can be saved to a file" do
      @model.save(@file_path)
      File.exist?(@file_path).should == true
    end

    it "can be serialized to a string" do
      @model.serialize.should_not be_empty
    end

    it "should generate the same text for serialize and save" do
      @model.save(@file_path)
      @model.serialize.should == IO.read(@file_path)
    end

    it "can be asked for it's svm_type" do
      @model.svm_type.should == Parameter::C_SVC
    end

    it "can be asked for it's number of classes (aka. labels)" do
      @model.classes.should == 2
    end

    it "can predict" do
      prediction = @model.predict(create_example)
      prediction.should_not be_nil
    end

    it "can predict probability" do
      prediction, probabilities = @model.predict_probability(create_example)
      prediction.should_not be_nil
      probabilities.size.should == @model.classes
      probabilities.each { |e| e.should_not be_nil }
    end

    it "can predict with block" do
      prediction = @model.predict(create_example) do |probabilities|
        probabilities.should be_all { |p| p.kind_of? Float }
        probabilities.count.should == @model.classes
      end
      prediction.should be_a Float
    end
  end
end
|
data/spec/node_spec.rb
ADDED
@@ -0,0 +1,62 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Node do
  context "construction" do
    it "using the properties" do
      built = Node.new
      built.index = 11
      built.value = 0.11
      built.index.should == 11
      built.value.should be_within(0.0001).of(0.11)
    end

    it "using the constructor parameters" do
      built = Node.new(14, 0.14)
      built.index.should == 14
      built.value.should be_within(0.0001).of(0.14)
    end
  end

  context "inner workings" do
    let(:node) { Node.new }

    it "can be created" do
      node.should_not be_nil
    end

    it "does not segfault on setting properties" do
      node.index = 99
      node.index.should == 99
      node.value = 3.141
      node.value.should be_within(0.00001).of(3.141)
    end

    it "has inited properties" do
      node.index.should == 0
      node.value.should be_within(0.00001).of(0)
    end

    it "class can create nodes from an array" do
      nodes = Node.features([0.1, 0.2, 0.3, 0.4, 0.5])
      nodes.each { |feature| feature.class.should == Node }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from variable parameters" do
      nodes = Node.features(0.1, 0.2, 0.3, 0.4, 0.5)
      nodes.each { |feature| Node.should === feature }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from hash" do
      nodes = Node.features(3 => 0.3, 5 => 0.5, 6 => 0.6, 10 => 1.0)
      nodes.each { |feature| feature.class.should == Node }
      nodes.map(&:value).sort.should == [0.3, 0.5, 0.6, 1.0]
      nodes.map(&:index).sort.should == [3, 5, 6, 10]
    end

    it "implements a value-like equality, not identity-notion" do
      Node.new(1, 0.1).should == Node.new(1, 0.1)
    end
  end
end
|
@@ -0,0 +1,79 @@
|
|
1
|
+
require "spec_helper"
|
2
|
+
|
3
|
+
describe Parameter do
  before do
    @parameter = Libsvm::SvmParameter.new
  end

  it "can be created with a constructor" do
    construct = lambda do
      Libsvm::SvmParameter.new(svm_type: Libsvm::SvmType::C_SVC, cost: 23, gamma: 65)
    end
    construct.should_not raise_error
  end

  it "int svm_type" do
    SvmType::C_SVC.should == 0
    @parameter.svm_type = SvmType::C_SVC
    @parameter.svm_type.should == SvmType::C_SVC
  end

  it "int kernel_type" do
    KernelType::RBF.should == 2
    @parameter.kernel_type = KernelType::RBF
    @parameter.kernel_type.should == KernelType::RBF
  end

  it "int degree" do
    @parameter.degree = 99
    @parameter.degree.should == 99
  end

  it "double gamma" do
    @parameter.gamma = 0.33
    @parameter.gamma.should == 0.33
  end

  it "double coef0" do
    @parameter.coef0 = 0.99
    @parameter.coef0.should == 0.99
  end

  it "double cache_size" do
    @parameter.cache_size = 0.77
    @parameter.cache_size.should == 0.77
  end

  it "double eps" do
    @parameter.eps = 0.111
    @parameter.eps.should == 0.111
    @parameter.eps = 0.112
    @parameter.eps.should == 0.112
  end

  it "double C" do
    @parameter.c = 3.141
    @parameter.c.should == 3.141
  end

  it "can set and read weights (weight, weight_label, nr_weight members from struct)" do
    weights = { 1 => 1.2, 3 => 0.2, 5 => 0.888 }
    @parameter.label_weights = weights
    @parameter.label_weights.should == weights
  end

  it "double nu" do
    @parameter.nu = 1.1
    @parameter.nu.should == 1.1
  end

  it "double p" do
    @parameter.p = 0.123
    @parameter.p.should == 0.123
  end

  it "int shrinking" do
    @parameter.shrinking = 22
    @parameter.shrinking.should == 22
  end

  it "int probability" do
    @parameter.probability = 35
    @parameter.probability.should == 35
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe Problem do
  before(:each) do
    @problem = Problem.new
    @features = [[0.2, 0.3, 0.4, 0.4],
                 [0.1, 0.5, 0.1, 0.9],
                 [0.2, 0.2, 0.6, 0.5],
                 [0.3, 0.1, 0.5, 0.9]].map { |values| Node.features(*values) }
  end

  it "examples get stored and retrieved" do
    @problem.set_examples([1, 2, 1, 2], @features)
    labels, examples = @problem.examples
    labels.size.should == 4
    examples.size.should == 4
    examples.map(&:size).should == [4, 4, 4, 4]
    examples.first.map(&:index).should == [0, 1, 2, 3]
    examples.first.map(&:value).should == [0.2, 0.3, 0.4, 0.4]
  end

  it "can be populated" do
    rows = [[0.2, 0.3, 0.4, 0.4],
            [0.1, 0.5, 0.1, 0.9],
            [0.2, 0.2, 0.6, 0.5],
            [0.3, 0.1, 0.5, 0.9]].map { |values| Node.features(*values) }
    lambda { @problem.set_examples([1, 2, 1, 2], rows) }.should_not raise_error
  end

  it "can be set twice over" do
    build_pair = lambda { [Node.features(0.2, 0.3, 0.4, 0.4), Node.features(0.3, 0.1, 0.5, 0.9)] }
    @problem.set_examples([1, 2], build_pair.call)
    second_pair = build_pair.call
    lambda { @problem.set_examples([8, 2], second_pair) }.should_not raise_error
  end
end
|
data/spec/spec_helper.rb
ADDED
data/spec/usage_spec.rb
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe "Basic usage" do
  before do
    @problem = Problem.new
    @parameter = Parameter.new.tap do |settings|
      settings.cache_size = 1 # mb

      # "eps is the stopping criterion (we usually use 0.00001 in nu-SVC,
      # 0.001 in others)." (from README)
      settings.eps = 0.001

      settings.c = 10
    end
  end

  it "has a nice API" do
    feature_map = { 11 => 0.11, 21 => 0.21, 101 => 0.99 }
    feature_map.to_example.should == Node.features(feature_map)
  end

  it "is as in [PCI,217]" do
    training = [[1, 0, 1], [-1, 0, -1]].map { |values| Node.features(values) }
    @problem.set_examples([1, -1], training)

    model = Model.train(@problem, @parameter)

    expectations = {
      [1, 1, 1]   => 1.0,
      [-1, 1, -1] => -1.0,
      [-1, 55, -1] => -1.0
    }
    expectations.each do |values, expected_label|
      model.predict(Node.features(*values)).should == expected_label
    end
  end

  it "kernel parameter use" do
    @parameter.kernel_type = Parameter::RBF
    training = [[1, 2, 3], [-2, -2, -2]].map { |values| Node.features(values) }
    @problem.set_examples([1, 2], training)

    # NOTE(review): gamma is left at Parameter's default of 0.0, so the RBF
    # kernel is constant. The expected label 2 for a point that was trained
    # as label 1 may reflect that, not real RBF behaviour. Confirm.
    model = Model.train(@problem, @parameter)

    model.predict(Node.features(1, 2, 3)).should == 2
  end
end
|
data/tmp/.gitkeep
ADDED
File without changes
|