jrb-libsvm 0.1.2-java

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,23 @@
1
# -*- encoding: utf-8 -*-
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'jrb-libsvm/version'

# Gem specification for the JRuby bindings around the Java libsvm library.
Gem::Specification.new do |gem|
  gem.name          = "jrb-libsvm"
  gem.version       = JrbLibsvm::VERSION
  # The gem wraps Java classes, so it is published for the java platform only.
  gem.platform      = 'java'
  gem.authors       = ["Andreas Eger"]
  gem.email         = ["dev@eger-andreas.de"]
  gem.description   = %q{JRuby language bindings for libsvm}
  gem.summary       = %q{basic wrapper around the java libsvm library}
  gem.homepage      = "https://github.com/sch1zo/jrb-libsvm"

  # Package exactly the files tracked by git (must be built from a checkout).
  gem.files         = `git ls-files`.split($/)
  gem.executables   = gem.files.grep(%r{^bin/}).map { |f| File.basename(f) }
  gem.test_files    = gem.files.grep(%r{^(test|spec|features)/})
  gem.require_paths = ["lib"]

  # The specs rely on RSpec 2 matchers such as `be_true` and `have(n).items`,
  # which RSpec 3 removed, so stay on the 2.x series.
  gem.add_development_dependency('rspec', '~> 2.7')
end
Binary file
@@ -0,0 +1,97 @@
1
+ java_import 'java.io.DataOutputStream'
2
+ java_import 'java.io.ByteArrayOutputStream'
3
+
4
+ java_import 'java.io.StringReader'
5
+ java_import 'java.io.BufferedReader'
6
+
7
module Libsvm
  # Ruby-side additions to the libsvm Model class.
  class Model
    class << self
      # Train a new model on +problem+ using the settings in +parameter+.
      def train(problem, parameter)
        Svm.svm_train(problem, parameter)
      end

      # Load a model from the file at +filename+.
      # Raises IOError when libsvm reports a java.io.IOException.
      def load(filename)
        Svm.svm_load_model(filename)
      rescue java.io.IOException
        raise IOError, "Error in loading SVM model from file"
      end

      # Load a model from a String, e.g. one produced by #serialize.
      # Raises IOError when libsvm reports a java.io.IOException.
      def parse(string)
        reader = BufferedReader.new(StringReader.new(string))
        Svm.svm_load_model(reader)
      rescue java.io.IOException
        raise IOError, "Error in loading SVM model from string"
      end
    end

    # SVM problem type this model was trained for (compare with SvmType).
    def svm_type
      param.svm_type
    end

    # Kernel type this model was trained with (compare with KernelType).
    def kernel_type
      param.kernel_type
    end

    # Value of the degree parameter.
    def degree
      param.degree
    end

    # Value of the gamma parameter.
    def gamma
      param.gamma
    end

    # Value of the cost (C) parameter.
    def cost
      param.c
    end

    # Number of classes handled by this model.
    def classes
      nr_class
    end

    # Save the model to +filename+.
    # Raises IOError (carrying the Java exception text) on any I/O failure.
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException => e
      # Bind the caught exception (not its class) so the cause reaches the caller.
      raise IOError, "Error in saving SVM model to file: #{e}"
    end

    # Serialize the model and return it as a String.
    # Raises IOError when libsvm reports a java.io.IOException.
    def serialize
      stream = ByteArrayOutputStream.new
      Svm.svm_save_model(DataOutputStream.new(stream), self)
      stream.to_s
    rescue java.io.IOException
      raise IOError, "Error in serializing SVM model"
    end

    # String form of the model is its serialized representation.
    def to_s
      serialize
    end

    # Predict the label of +example+ (an Array of Nodes).
    #
    # Without a block this returns Svm.svm_predict's result. With a block it
    # computes probability estimates, yields them (one entry per class) and
    # returns the prediction. NOTE(review): probability estimates presumably
    # require a model trained with probability = 1 — confirm.
    def predict(example)
      return Svm.svm_predict(self, example) unless block_given?

      prediction, probabilities = predict_probability(example)
      yield probabilities
      prediction
    end

    # Predict with probability estimates.
    # Returns [prediction, probabilities], where probabilities is a Java
    # double[] sized to the number of classes.
    def predict_probability(example)
      probabilities = Java::double[classes].new
      prediction = Svm.svm_predict_probability(self, example, probabilities)
      [prediction, probabilities]
    end
  end
end
@@ -0,0 +1,37 @@
1
module Libsvm
  # Ruby-side additions to the libsvm Node class (one index/value pair).
  class Node
    class << self
      # Build an Array of Nodes from feature values.
      #
      # Accepted forms:
      # * an Array of values            -> Nodes indexed 0, 1, 2, ...
      # * a Hash of index => value      -> Nodes with the given indices
      # * numeric values as separate arguments (including a single one)
      #                                 -> Nodes indexed 0, 1, 2, ...
      #
      # Raises ArgumentError for a single argument of any other type.
      def features(*vargs)
        return nodes_from_values(vargs) unless vargs.size == 1

        case (arg = vargs.first)
        when Array   then nodes_from_values(arg)
        when Hash    then arg.map { |index, value| Node.new(index.to_i, value.to_f) }
        when Numeric then nodes_from_values(vargs)
        else
          raise ArgumentError, "Node features need to be a Hash, Array or Floats"
        end
      end

      private

      # One Node per value, indexed by position.
      def nodes_from_values(values)
        values.each_with_index.map { |value, index| Node.new(index, value.to_f) }
      end
    end

    # Create a node with the given +index+ and +value+.
    def initialize(index = 0, value = 0.0)
      super()
      self.index = index
      self.value = value
    end

    # Value equality: equal when both index and value match.
    # Objects without index/value readers (e.g. nil) compare unequal instead
    # of raising NoMethodError.
    def ==(other)
      other.respond_to?(:index) && other.respond_to?(:value) &&
        index == other.index && value == other.value
    end
  end
end
@@ -0,0 +1,66 @@
1
module Libsvm
  # Ruby-side additions to the libsvm Parameter class.
  class Parameter
    # Lower-case Ruby accessors for the cost field C.
    alias :c :C
    alias :c= :C=

    # Class weights as a Hash of label => weight.
    # Returns an empty Hash when no weights have been assigned. NOTE(review):
    # initialize never sets weight_label/weight, so they are presumably nil
    # until label_weights= is called — confirm against the Java class.
    def label_weights
      labels = weight_label
      weights = weight
      return {} if labels.nil? || weights.nil?
      Hash[labels.zip(weights)]
    end

    # Set the class weights from a Hash of label => weight (also sets nr_weight).
    def label_weights=(v)
      self.nr_weight = v.keys.size
      self.weight_label = v.keys.to_java(:int)
      self.weight = v.values.to_java(:double)
    end

    # Build a parameter set from an options Hash.
    #
    # Valid keys with their default values:
    # * :svm_type    = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost        = 1.0, for the cost or C parameter
    # * :gamma       = 0.0, for the gamma parameter in kernel
    # * :degree      = 1, for polynomial kernel
    # * :coef0       = 0.0, for polynomial/sigmoid kernels
    # * :eps         = 0.001, for stopping criterion
    # * :nr_weight   = 0, for C_SVC
    # * :nu          = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p           = 0.1, used for EPSILON_SVR
    # * :shrinking   = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    # * :cache_size  = 1, kernel cache size (MB)
    #
    # Raises ArgumentError when nu is outside (0, 1].
    def initialize(args = {})
      super()
      self.svm_type    = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C           = args.fetch(:cost, 1.0)
      self.gamma       = args.fetch(:gamma, 0.0)
      self.degree      = args.fetch(:degree, 1)
      self.coef0       = args.fetch(:coef0, 0.0)
      self.eps         = args.fetch(:eps, 0.001)
      self.nr_weight   = args.fetch(:nr_weight, 0)
      self.nu          = args.fetch(:nu, 0.5)
      self.p           = args.fetch(:p, 0.1)
      self.shrinking   = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)
      self.cache_size  = args.fetch(:cache_size, 1)

      unless self.nu > 0.0 && self.nu <= 1.0
        raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]"
      end
    end
  end
  SvmParameter = Parameter

  # Kernel type constants, re-exported from Parameter.
  module KernelType
    LINEAR = Parameter::LINEAR
    POLY = Parameter::POLY
    RBF = Parameter::RBF
    SIGMOID = Parameter::SIGMOID
    PRECOMPUTED = Parameter::PRECOMPUTED
  end

  # SVM type constants, re-exported from Parameter.
  module SvmType
    C_SVC = Parameter::C_SVC
    NU_SVC = Parameter::NU_SVC
    ONE_CLASS = Parameter::ONE_CLASS
    EPSILON_SVR = Parameter::EPSILON_SVR
    NU_SVR = Parameter::NU_SVR
  end
end
@@ -0,0 +1,35 @@
1
module Libsvm
  # Ruby-side additions to the libsvm Problem class (labels plus examples).
  class Problem
    # Replace the problem data with +labels+ and the parallel +features+ rows.
    #
    # +features+ is an Array of rows, each an Array of Node objects. All rows
    # must have the same length because they are copied into a rectangular
    # Node[rows, columns] Java array.
    #
    # Returns the number of examples stored.
    # Raises ArgumentError when the counts differ, there is no data, or rows
    # have different lengths.
    def set_examples(labels, features)
      if labels.size != features.size
        raise ArgumentError, "Number of features must equal number of labels"
      end
      if features.size.zero?
        raise ArgumentError, "There must be at least one feature."
      end
      row_lengths = features.map(&:size)
      unless row_lengths.min == row_lengths.max
        raise ArgumentError, "All features must have the same size"
      end

      self.l = labels.size

      # Training data: fill the freshly allocated 2-D Node array cell by cell.
      self.x = Node[features.size, features.first.size].new
      features.each_with_index do |row, row_index|
        row.each_with_index { |node, col_index| self.x[row_index][col_index] = node }
      end

      # Labels: one entry per example in a Java double[].
      self.y = Java::double[labels.size].new
      labels.each_with_index { |label, index| self.y[index] = label }

      labels.size
    end

    # The stored data as a two-element Array: [labels, examples].
    def examples
      [y, x]
    end
  end
end
@@ -0,0 +1,3 @@
1
module JrbLibsvm
  # Gem version string; frozen so the shared constant cannot be mutated.
  VERSION = "0.1.2".freeze
end
data/lib/jrb-libsvm.rb ADDED
@@ -0,0 +1,31 @@
1
# Library entry point: loads the bundled Java libsvm classes, exposes them
# under the Libsvm namespace, applies the Ruby extensions and adds
# #to_example to Hash and Array.
require_relative "jrb-libsvm/version"
require "java"
require_relative "java/libsvm"

module Libsvm
  # Make the Java classes available as Libsvm::Parameter, Libsvm::Model, ...
  java_import "libsvm.Parameter"
  java_import "libsvm.Model"
  java_import "libsvm.Problem"
  java_import "libsvm.Node"
  java_import "libsvm.Svm"

  module CoreExtensions
    # Mixin that turns a collection of feature values into Nodes.
    module Collection
      # Shorthand for Libsvm::Node.features(self).
      def to_example
        Node.features(self)
      end
    end
  end
end

# The Ruby extensions reopen the classes imported above (and refer to their
# constants), so they have to be loaded after the java_import calls.
%w[parameter model node problem].each do |part|
  require_relative "jrb-libsvm/#{part}"
end

# Let plain Hashes and Arrays be converted with #to_example.
[Hash, Array].each do |core_class|
  core_class.send(:include, Libsvm::CoreExtensions::Collection)
end
@@ -0,0 +1,119 @@
1
require "spec_helper"

# Specs for Libsvm::Model: training, saving/loading, serialization and
# prediction on a small two-class problem.
describe Model do
  def create_example
    Node.features(0.2, 0.3, 0.4, 0.5)
  end

  def create_problem
    problem = Problem.new
    features = [Node.features([0.2,0.3,0.4,0.4]),
                Node.features([0.1,0.5,0.1,0.9]),
                Node.features([0.2,0.2,0.6,0.5]),
                Node.features([0.3,0.1,0.5,0.9])]
    problem.set_examples([1,2,1,2], features)
    problem
  end

  def create_parameter
    parameter = Parameter.new
    parameter.cache_size = 50 # mb
    parameter.eps = 0.01
    parameter.c = 10
    parameter.probability = 1
    parameter
  end

  context "The class interface" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
    end

    it "results from training on a problem under a certain parameter set" do
      model = Model.train(@problem,@parameter)
      model.should_not be_nil
    end
  end

  context "A saved model" do
    before(:each) do
      @filename = "tmp/svm_model.model"
      model = Model.train(create_problem, create_parameter)
      model.save(@filename)
      @model_string = model.to_s
    end

    it "can be loaded from a file" do
      model = Model.load(@filename)
      model.should_not be_nil
    end

    it "can be loaded from a string" do
      model = Model.parse @model_string
      model.should_not be_nil
    end

    it "should do the same for load and load_from_string" do
      model = Model.load @filename
      model2 = Model.parse @model_string
      model.serialize.should == model2.serialize
    end

    after(:each) do
      # Remove only the file the before hook wrote; an inline `rescue nil`
      # would also hide unrelated errors.
      File.delete(@filename) if File.exist?(@filename)
    end
  end

  context "An Libsvm model" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
      @model = Model.train(@problem, @parameter)
      @file_path = "tmp/svm_model.model"
    end
    after(:each) do
      # File.exists? is deprecated and was removed in Ruby 3.2.
      File.delete(@file_path) if File.exist?(@file_path)
    end

    it "can be saved to a file" do
      @model.save(@file_path)
      File.exist?(@file_path).should be_true
    end

    it "can be serialized to a string" do
      @model.serialize.should_not be_empty
    end

    it "should generate the same text for serialize and save" do
      @model.save(@file_path)
      @model.serialize.should == IO.read(@file_path)
    end

    it "can be asked for it's svm_type" do
      @model.svm_type.should == Parameter::C_SVC
    end

    it "can be asked for it's number of classes (aka. labels)" do
      @model.classes.should == 2
    end

    it "can predict" do
      prediction = @model.predict(create_example)
      prediction.should_not be_nil
    end
    it "can predict probability" do
      prediction, probabilities = @model.predict_probability(create_example)
      prediction.should_not be_nil
      probabilities.should have(@model.classes).items
      probabilities.each { |e| e.should_not be_nil }
    end
    it "can predict with block" do
      prediction = @model.predict(create_example) do |probabilities|
        probabilities.should be_all { |p| p.kind_of? Float }
        probabilities.count.should == @model.classes
      end
      prediction.should be_a Float
    end
  end
end
data/spec/node_spec.rb ADDED
@@ -0,0 +1,62 @@
1
require "spec_helper"

# Specs for Libsvm::Node: construction, accessors, the Node.features factory
# and value equality.
describe Node do
  context "construction" do
    it "using the properties" do
      node = Node.new
      node.index = 11
      node.value = 0.11
      node.index.should == 11
      node.value.should be_within(0.0001).of(0.11)
    end

    it "using the constructor parameters" do
      node = Node.new(14, 0.14)
      node.index.should == 14
      node.value.should be_within(0.0001).of(0.14)
    end
  end

  context "inner workings" do
    let(:node) { Node.new }

    it "can be created" do
      node.should_not be_nil
    end

    it "does not segfault on setting properties" do
      node.index = 99
      node.index.should == 99
      node.value = 3.141
      node.value.should be_within(0.00001).of(3.141)
    end

    it "has inited properties" do
      node.index.should == 0
      node.value.should be_within(0.00001).of(0)
    end

    it "class can create nodes from an array" do
      nodes = Node.features([0.1, 0.2, 0.3, 0.4, 0.5])
      nodes.each { |n| n.class.should == Node }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from variable parameters" do
      nodes = Node.features(0.1, 0.2, 0.3, 0.4, 0.5)
      nodes.each { |n| Node.should === n }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from hash" do
      nodes = Node.features(3 => 0.3, 5 => 0.5, 6 => 0.6, 10 => 1.0)
      nodes.each { |n| n.class.should == Node }
      nodes.map(&:value).sort.should == [0.3, 0.5, 0.6, 1.0]
      nodes.map(&:index).sort.should == [3, 5, 6, 10]
    end

    it "implements a value-like equality, not identity-notion" do
      Node.new(1, 0.1).should == Node.new(1, 0.1)
    end
  end
end
@@ -0,0 +1,79 @@
1
require "spec_helper"

# Specs for Libsvm::Parameter (also reachable as Libsvm::SvmParameter):
# construction plus setter/getter round trips for each field.
describe Parameter do
  # Define one example per [title, field, sample] row that writes +sample+
  # through the setter and expects to read the same value back.
  def self.it_round_trips(cases)
    cases.each do |title, field, sample|
      it title do
        @param.send("#{field}=", sample)
        @param.send(field).should == sample
      end
    end
  end

  before do
    @param = Libsvm::SvmParameter.new
  end

  it "can be created with a constructor" do
    build = -> { Libsvm::SvmParameter.new(svm_type: Libsvm::SvmType::C_SVC, cost: 23, gamma: 65) }
    build.should_not raise_error
  end

  it "int svm_type" do
    SvmType::C_SVC.should == 0
    @param.svm_type = SvmType::C_SVC
    @param.svm_type.should == SvmType::C_SVC
  end

  it "int kernel_type" do
    KernelType::RBF.should == 2
    @param.kernel_type = KernelType::RBF
    @param.kernel_type.should == KernelType::RBF
  end

  it_round_trips [
    ["int degree",        :degree,     99],
    ["double gamma",      :gamma,      0.33],
    ["double coef0",      :coef0,      0.99],
    ["double cache_size", :cache_size, 0.77]
  ]

  it "double eps" do
    [0.111, 0.112].each do |eps|
      @param.eps = eps
      @param.eps.should == eps
    end
  end

  it "double C" do
    @param.c = 3.141
    @param.c.should == 3.141
  end

  it "can set and read weights (weight, weight_label, nr_weight members from struct)" do
    weights = { 1 => 1.2, 3 => 0.2, 5 => 0.888 }
    @param.label_weights = weights
    @param.label_weights.should == weights
  end

  it_round_trips [
    ["double nu",       :nu,          1.1],
    ["double p",        :p,           0.123],
    ["int shrinking",   :shrinking,   22],
    ["int probability", :probability, 35]
  ]
end
@@ -0,0 +1,37 @@
1
require 'spec_helper'

# Specs for Libsvm::Problem#set_examples and #examples.
describe Problem do
  # Four 4-dimensional examples shared by the specs below.
  def four_examples
    [[0.2, 0.3, 0.4, 0.4],
     [0.1, 0.5, 0.1, 0.9],
     [0.2, 0.2, 0.6, 0.5],
     [0.3, 0.1, 0.5, 0.9]].map { |row| Node.features(*row) }
  end

  before(:each) do
    @problem = Problem.new
    @features = four_examples
  end

  it "examples get stored and retrieved" do
    @problem.set_examples([1, 2, 1, 2], @features)
    labels, examples = @problem.examples
    labels.size.should == 4
    examples.size.should == 4
    examples.map(&:size).should == [4, 4, 4, 4]
    examples.first.map(&:index).should == [0, 1, 2, 3]
    examples.first.map(&:value).should == [0.2, 0.3, 0.4, 0.4]
  end

  it "can be populated" do
    examples = four_examples
    -> { @problem.set_examples([1, 2, 1, 2], examples) }.should_not raise_error
  end

  it "can be set twice over" do
    pair = -> { [Node.features(0.2, 0.3, 0.4, 0.4), Node.features(0.3, 0.1, 0.5, 0.9)] }
    @problem.set_examples([1, 2], pair.call)
    second = pair.call
    -> { @problem.set_examples([8, 2], second) }.should_not raise_error
  end
end
@@ -0,0 +1,9 @@
1
# Shared RSpec setup: activate the bundle, load the library under test and
# expose the Libsvm constants (Model, Node, ...) to every spec.
require 'bundler'
Bundler.setup
Bundler.require :default, :test
require 'jrb-libsvm'

# Top-level include so specs can write Node instead of Libsvm::Node.
include Libsvm

RSpec.configure do |_config|
  # No custom configuration yet.
end
@@ -0,0 +1,47 @@
1
require 'spec_helper'

# End-to-end usage: build a problem, train a model and predict labels.
describe "Basic usage" do
  before do
    @problem = Problem.new
    @parameter = Parameter.new.tap do |param|
      param.cache_size = 1 # mb

      # "eps is the stopping criterion (we usually use 0.00001 in nu-SVC,
      # 0.001 in others)." (from README)
      param.eps = 0.001

      param.c = 10
    end
  end

  it "has a nice API" do
    raw = { 11 => 0.11, 21 => 0.21, 101 => 0.99 }
    raw.to_example.should == Node.features(raw)
  end

  it "is as in [PCI,217]" do
    examples = [[1, 0, 1], [-1, 0, -1]].map { |ary| Node.features(ary) }
    @problem.set_examples([1, -1], examples)

    model = Model.train(@problem, @parameter)

    {
      [1, 1, 1]    => 1.0,
      [-1, 1, -1]  => -1.0,
      [-1, 55, -1] => -1.0
    }.each do |point, expected|
      model.predict(Node.features(*point)).should == expected
    end
  end

  it "kernel parameter use" do
    # NOTE(review): gamma stays at its 0.0 default here, so the RBF kernel
    # value is constant; the expected label 2 likely reflects that
    # degenerate setup rather than real separation — confirm.
    @parameter.kernel_type = Parameter::RBF
    examples = [[1, 2, 3], [-2, -2, -2]].map { |ary| Node.features(ary) }
    @problem.set_examples([1, 2], examples)

    model = Model.train(@problem, @parameter)

    model.predict(Node.features(1, 2, 3)).should == 2
  end
end
data/tmp/.gitkeep ADDED
File without changes