jrb-libsvm 0.1.2-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,23 @@
1
+ # -*- encoding: utf-8 -*-
2
# Put this gem's own lib/ directory on the load path so the version file can
# be required straight from the source tree.
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'jrb-libsvm/version'

# Gem metadata for the JRuby ('java' platform) build of jrb-libsvm.
Gem::Specification.new do |gem|
  gem.name          = "jrb-libsvm"
  gem.version       = JrbLibsvm::VERSION
  gem.platform      = 'java'
  gem.authors       = ["Andreas Eger"]
  gem.email         = ["dev@eger-andreas.de"]
  gem.description   = %q{JRuby language bindings for libsvm}
  gem.summary       = %q{basic wrapper around the java libsvm library}
  gem.homepage      = "https://github.com/sch1zo/jrb-libsvm"

  # The packaged file list is whatever git tracks; executables and test
  # files are derived from that list by path prefix.
  gem.files         = `git ls-files`.split($/)
  gem.executables   = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) }
  gem.test_files    = gem.files.grep(%r{^(test|spec|features)/})
  gem.require_paths = ["lib"]

  # specify any dependencies here
  gem.add_development_dependency('rspec', '>= 2.7.0')
end
Binary file
@@ -0,0 +1,97 @@
1
+ java_import 'java.io.DataOutputStream'
2
+ java_import 'java.io.ByteArrayOutputStream'
3
+
4
+ java_import 'java.io.StringReader'
5
+ java_import 'java.io.BufferedReader'
6
+
7
module Libsvm
  # Ruby-side convenience methods for the SVM model class: training,
  # persistence (file and string) and prediction.
  #
  # NOTE(review): +param+ and +nr_class+ are not defined in this file;
  # presumably they are provided by the imported Java model class — confirm
  # against the bundled jar.
  class Model
    class << self
      # Train a model for +problem+ under +parameter+.
      # Returns the result of Svm.svm_train.
      def train(problem, parameter)
        Svm.svm_train(problem, parameter)
      end
    end

    # Return the SVM problem type for this model
    def svm_type
      param.svm_type
    end

    # Return the kernel type for this model
    def kernel_type
      param.kernel_type
    end

    # Return the value of the degree parameter
    def degree
      param.degree
    end

    # Return the value of the gamma parameter
    def gamma
      param.gamma
    end

    # Return the value of the cost parameter
    def cost
      param.c
    end

    # Return the number of classes handled by this model.
    def classes
      nr_class
    end

    # Save model to given filename.
    # Raises IOError (with the Java exception in the message) when the
    # underlying call raises java.io.IOException.
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException => e
      # `=> e` binds the caught exception; the message must describe the
      # actual failure, not the exception class.
      raise IOError, "Error in saving SVM model to file: #{e}"
    end

    # Serialize model and return a string.
    # Raises IOError when the underlying call raises java.io.IOException.
    def serialize
      buffer = ByteArrayOutputStream.new
      Svm.svm_save_model(DataOutputStream.new(buffer), self)
      buffer.to_s
    rescue java.io.IOException
      raise IOError, "Error in serializing SVM model"
    end

    # String form of the model; same as #serialize.
    def to_s
      serialize
    end

    # Load model from given filename.
    # Raises IOError when the underlying call raises java.io.IOException.
    def self.load(filename)
      Svm.svm_load_model(filename)
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from file"
    end

    # Load model from string.
    # Raises IOError when the underlying call raises java.io.IOException.
    def self.parse(string)
      Svm.svm_load_model(BufferedReader.new(StringReader.new(string)))
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from string"
    end

    # Predict the label for +example+.
    #
    # Without a block, returns the result of Svm.svm_predict. With a block,
    # the per-class probabilities are yielded to it and the prediction from
    # #predict_probability is returned.
    def predict(example, &block)
      return Svm.svm_predict(self, example) unless block

      prediction, probabilities = predict_probability(example)
      yield probabilities
      prediction
    end

    # Returns a two-element array: the predicted label and a Java double
    # array (one slot per class) filled in by Svm.svm_predict_probability.
    def predict_probability(example)
      probabilities = Java::double[classes].new
      [Svm.svm_predict_probability(self, example, probabilities), probabilities]
    end
  end
end
@@ -0,0 +1,37 @@
1
module Libsvm
  # Ruby-side helpers for the node class: a feature-vector factory, a
  # two-argument constructor and value-based equality.
  #
  # NOTE(review): the +index+/+value+ accessors are not defined in this file;
  # presumably they come from the imported Java class — confirm.
  class Node
    class << self
      # Build an array of nodes from feature values.
      #
      # Accepted argument shapes:
      # * a single Array (or subclass): position => value
      # * a single Hash (or subclass): key => value, keys converted with to_i
      # * a single Numeric: one node at index 0
      # * several values: position => value
      #
      # Values are converted with to_f. Raises ArgumentError when a single
      # argument is none of Hash, Array or Numeric.
      def features(*vargs)
        source = vargs.size == 1 ? vargs.first : vargs

        pairs =
          case source
          when Hash    then source.to_a
          when Array   then source.each_with_index.map { |value, index| [index, value] }
          when Numeric then [[0, source]]
          else
            raise ArgumentError, "Node features need to be a Hash, Array or Floats"
          end

        pairs.map { |index, value| Node.new(index.to_i, value.to_f) }
      end
    end

    # Create a node with the given +index+ and +value+ (defaults 0 and 0.0).
    def initialize(index=0, value=0.0)
      super()
      self.index = index
      self.value = value
    end

    # Value equality: true when +other+ exposes index and value and both
    # match. Objects without those readers (nil included) compare unequal
    # instead of raising NoMethodError.
    def ==(other)
      other.respond_to?(:index) && other.respond_to?(:value) &&
        index == other.index && value == other.value
    end
  end
end
@@ -0,0 +1,66 @@
1
module Libsvm
  # Ruby-side conveniences for the parameter class: lower-case aliases for
  # the C accessor, label-weight helpers and a hash-driven constructor.
  class Parameter
    alias :c :C
    alias :c= :C=

    # Returns a Hash mapping each weight label to its weight.
    def label_weights
      Hash[weight_label.zip(weight)]
    end

    # Sets nr_weight, weight_label and weight from a label => weight Hash.
    def label_weights=(weights)
      self.nr_weight = weights.keys.size
      self.weight_label = weights.keys.to_java :int
      self.weight = weights.values.to_java :double
    end

    # Constructor sets up values of attributes based on provided map.
    # Valid keys with their default values:
    # * :svm_type = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost = 1.0, for the cost or C parameter
    # * :gamma = 0.0, for the gamma parameter in kernel
    # * :degree = 1, for polynomial kernel
    # * :coef0 = 0.0, for polynomial/sigmoid kernels
    # * :eps = 0.001, for stopping criterion
    # * :nr_weight = 0, for C_SVC
    # * :nu = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p = 0.1, used for EPSILON_SVR
    # * :shrinking = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    # * :cache_size = 1
    #
    # Raises ArgumentError when nu is outside (0,1].
    def initialize(args = {})
      super()
      self.svm_type    = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C           = args.fetch(:cost, 1.0)
      self.gamma       = args.fetch(:gamma, 0.0)
      self.degree      = args.fetch(:degree, 1)
      self.coef0       = args.fetch(:coef0, 0.0)
      self.eps         = args.fetch(:eps, 0.001)
      self.nr_weight   = args.fetch(:nr_weight, 0)
      self.nu          = args.fetch(:nu, 0.5)
      self.p           = args.fetch(:p, 0.1)
      self.shrinking   = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)
      self.cache_size  = args.fetch(:cache_size, 1)

      # `raise ArgumentError, msg` (with the comma) is required: without it
      # Ruby looks for a method named ArgumentError and raises NoMethodError.
      unless self.nu > 0.0 && self.nu <= 1.0
        raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]"
      end
    end
  end
  SvmParameter = Parameter

  # Kernel type constants, re-exported from Parameter.
  module KernelType
    LINEAR = Parameter::LINEAR
    POLY = Parameter::POLY
    RBF = Parameter::RBF
    SIGMOID = Parameter::SIGMOID
    PRECOMPUTED = Parameter::PRECOMPUTED
  end

  # SVM type constants, re-exported from Parameter.
  module SvmType
    C_SVC = Parameter::C_SVC
    NU_SVC = Parameter::NU_SVC
    ONE_CLASS = Parameter::ONE_CLASS
    EPSILON_SVR = Parameter::EPSILON_SVR
    NU_SVR = Parameter::NU_SVR
  end
end
@@ -0,0 +1,35 @@
1
module Libsvm
  # Ruby-side helpers for loading training data into a problem and reading
  # it back.
  class Problem
    # Store the training set: +labels+ go into y, +features+ (one collection
    # of nodes per label) into x, and l is set to the number of labels.
    #
    # Raises ArgumentError when the two collections differ in size, when
    # +features+ is empty, or when the feature rows differ in size.
    # Returns the number of labels.
    def set_examples(labels, features)
      if features.size != labels.size
        raise ArgumentError.new("Number of features must equal number of labels")
      end
      if features.size < 1
        raise ArgumentError.new("There must be at least one feature.")
      end

      row_sizes = features.map { |row| row.size }
      if row_sizes.min != row_sizes.max
        raise ArgumentError.new("All features must have the same size")
      end

      example_count = labels.size
      self.l = example_count

      # Copy the nodes into a freshly allocated two-dimensional Node array.
      self.x = Node[features.size, features[0].size].new
      features.each_with_index do |row, row_index|
        row.each_with_index do |node, column_index|
          self.x[row_index][column_index] = node
        end
      end

      # Copy the labels into a freshly allocated double array.
      self.y = Java::double[example_count].new
      labels.each_with_index { |label, position| self.y[position] = label }

      example_count
    end

    # Returns the stored data as a two-element array: y first, then x.
    def examples
      [self.y, self.x]
    end
  end
end
@@ -0,0 +1,3 @@
1
# Namespace holding the gem's release metadata.
module JrbLibsvm
  # Version string of this release.
  VERSION = '0.1.2'
end
data/lib/jrb-libsvm.rb ADDED
@@ -0,0 +1,31 @@
1
+ require_relative "jrb-libsvm/version"
2
+ require "java"
3
+ require_relative "java/libsvm"
4
+
5
module Libsvm
  # Import the Java classes of the `libsvm` package into this namespace,
  # one java_import call per class.
  %w[Parameter Model Problem Node Svm].each do |java_class|
    java_import "libsvm.#{java_class}"
  end

  module CoreExtensions
    # Mixin adding #to_example to a collection.
    module Collection
      # Convert the receiver into nodes by handing it to Node.features.
      def to_example
        Node.features(self)
      end
    end
  end
end
20
+
21
+ require_relative 'jrb-libsvm/parameter'
22
+ require_relative 'jrb-libsvm/model'
23
+ require_relative 'jrb-libsvm/node'
24
+ require_relative 'jrb-libsvm/problem'
25
+
26
# Mix the collection helper into the core Hash and Array classes so both
# respond to the methods of Libsvm::CoreExtensions::Collection.
[Hash, Array].each do |collection_class|
  collection_class.send(:include, Libsvm::CoreExtensions::Collection)
end
@@ -0,0 +1,119 @@
1
+ require "spec_helper"
2
+
3
# Specs for Model: training, saving/loading via file and string, metadata
# readers and prediction.
describe Model do
  # A four-feature example used by the prediction specs.
  def create_example
    Node.features(0.2, 0.3, 0.4, 0.5)
  end

  # A problem with four four-feature examples labelled 1, 2, 1, 2.
  def create_problem
    problem = Problem.new
    features = [Node.features([0.2,0.3,0.4,0.4]),
                Node.features([0.1,0.5,0.1,0.9]),
                Node.features([0.2,0.2,0.6,0.5]),
                Node.features([0.3,0.1,0.5,0.9])]
    problem.set_examples([1,2,1,2], features)
    problem
  end

  # Parameters with probability estimates switched on, which the
  # probability-prediction specs depend on.
  def create_parameter
    parameter = Parameter.new
    parameter.cache_size = 50 # mb
    parameter.eps = 0.01
    parameter.c = 10
    parameter.probability = 1
    parameter
  end

  context "The class interface" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
    end

    it "results from training on a problem under a certain parameter set" do
      model = Model.train(@problem,@parameter)
      model.should_not be_nil
    end
  end

  context "A saved model" do
    before(:each) do
      @filename = "tmp/svm_model.model"
      model = Model.train(create_problem, create_parameter)
      model.save(@filename)
      @model_string = model.to_s
    end

    it "can be loaded from a file" do
      model = Model.load(@filename)
      model.should_not be_nil
    end

    it "can be loaded from a string" do
      model = Model.parse @model_string
      model.should_not be_nil
    end

    it "should do the same for load and load_from_string" do
      model = Model.load @filename
      model2 = Model.parse @model_string
      model.serialize.should == model2.serialize
    end

    after(:each) do
      # Guarded delete instead of `rescue nil`, so real cleanup failures are
      # not silently swallowed.
      File.delete(@filename) if File.exist?(@filename)
    end
  end

  context "An Libsvm model" do
    before(:each) do
      @problem = create_problem
      @parameter = create_parameter
      @model = Model.train(@problem, @parameter)
      @file_path = "tmp/svm_model.model"
    end
    after(:each) do
      # File.exist? rather than the deprecated File.exists? alias.
      File.delete(@file_path) if File.exist?(@file_path)
    end

    it "can be saved to a file" do
      @model.save(@file_path)
      File.exist?(@file_path).should be_true
    end

    it "can be serialized to a string" do
      @model.serialize.should_not be_empty
    end

    it "should generate the same text for serialize and save" do
      @model.save(@file_path)
      @model.serialize.should == IO.read(@file_path)
    end

    it "can be asked for it's svm_type" do
      @model.svm_type.should == Parameter::C_SVC
    end

    it "can be asked for it's number of classes (aka. labels)" do
      @model.classes.should == 2
    end

    it "can predict" do
      prediction = @model.predict(create_example)
      prediction.should_not be_nil
    end
    it "can predict probability" do
      prediction, probabilities = @model.predict_probability(create_example)
      prediction.should_not be_nil
      probabilities.should have(@model.classes).items
      probabilities.each { |e| e.should_not be_nil }
    end
    it "can predict with block" do
      prediction = @model.predict(create_example) do |probabilities|
        probabilities.should be_all { |p| p.kind_of? Float }
        probabilities.count.should == @model.classes
      end
      prediction.should be_a Float
    end
  end
end
data/spec/node_spec.rb ADDED
@@ -0,0 +1,62 @@
1
+ require "spec_helper"
2
+
3
# Specs for Node: construction, property access, the features factory and
# value equality.
describe Node do
  context "construction" do
    it "using the properties" do
      built = Node.new
      built.index = 11
      built.value = 0.11
      built.index.should == 11
      built.value.should be_within(0.0001).of(0.11)
    end

    it "using the constructor parameters" do
      built = Node.new(14, 0.14)
      built.index.should == 14
      built.value.should be_within(0.0001).of(0.14)
    end
  end

  context "inner workings" do
    let(:node) { Node.new }

    it "can be created" do
      node.should_not be_nil
    end

    it "does not segfault on setting properties" do
      node.index = 99
      node.index.should == 99
      node.value = 3.141
      node.value.should be_within(0.00001).of(3.141)
    end

    it "has inited properties" do
      node.index.should == 0
      node.value.should be_within(0.00001).of(0)
    end

    it "class can create nodes from an array" do
      nodes = Node.features([0.1, 0.2, 0.3, 0.4, 0.5])
      nodes.each { |item| item.class.should == Node }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from variable parameters" do
      nodes = Node.features(0.1, 0.2, 0.3, 0.4, 0.5)
      nodes.each { |item| Node.should === item }
      nodes.map(&:value).should == [0.1, 0.2, 0.3, 0.4, 0.5]
    end

    it "class can create nodes from hash" do
      nodes = Node.features({ 3 => 0.3, 5 => 0.5, 6 => 0.6, 10 => 1.0 })
      nodes.each { |item| item.class.should == Node }
      # Sorted, so the assertions do not depend on hash iteration order.
      nodes.map(&:value).sort.should == [0.3, 0.5, 0.6, 1.0]
      nodes.map(&:index).sort.should == [3, 5, 6, 10]
    end

    it "implements a value-like equality, not identity-notion" do
      Node.new(1, 0.1).should == Node.new(1, 0.1)
    end
  end
end
@@ -0,0 +1,79 @@
1
+ require "spec_helper"
2
+
3
# Specs for Parameter: construction from a hash and a write/read round trip
# for each attribute.
describe Parameter do
  before(:each) { @param = Libsvm::SvmParameter.new }

  it "can be created with a constructor" do
    build = lambda do
      Libsvm::SvmParameter.new(:svm_type => Libsvm::SvmType::C_SVC, :cost => 23, :gamma => 65)
    end
    build.should_not raise_error
  end

  it "int svm_type" do
    SvmType::C_SVC.should == 0
    @param.svm_type = SvmType::C_SVC
    @param.svm_type.should == SvmType::C_SVC
  end

  it "int kernel_type" do
    KernelType::RBF.should == 2
    @param.kernel_type = KernelType::RBF
    @param.kernel_type.should == KernelType::RBF
  end

  it "int degree" do
    @param.degree = 99
    @param.degree.should == 99
  end

  it "double gamma" do
    @param.gamma = 0.33
    @param.gamma.should == 0.33
  end

  it "double coef0" do
    @param.coef0 = 0.99
    @param.coef0.should == 0.99
  end

  it "double cache_size" do
    @param.cache_size = 0.77
    @param.cache_size.should == 0.77
  end

  it "double eps" do
    # Written twice to show a later assignment replaces the earlier one.
    @param.eps = 0.111
    @param.eps.should == 0.111
    @param.eps = 0.112
    @param.eps.should == 0.112
  end

  it "double C" do
    @param.c = 3.141
    @param.c.should == 3.141
  end

  it "can set and read weights (weight, weight_label, nr_weight members from struct)" do
    weights = { 1 => 1.2, 3 => 0.2, 5 => 0.888 }
    @param.label_weights = weights
    @param.label_weights.should == { 1 => 1.2, 3 => 0.2, 5 => 0.888 }
  end

  it "double nu" do
    @param.nu = 1.1
    @param.nu.should == 1.1
  end

  it "double p" do
    @param.p = 0.123
    @param.p.should == 0.123
  end

  it "int shrinking" do
    @param.shrinking = 22
    @param.shrinking.should == 22
  end

  it "int probability" do
    @param.probability = 35
    @param.probability.should == 35
  end
end
@@ -0,0 +1,37 @@
1
+ require 'spec_helper'
2
+
3
# Specs for Problem: storing examples, reading them back and re-populating.
describe Problem do
  before(:each) do
    @problem = Problem.new
    @features = [[0.2, 0.3, 0.4, 0.4],
                 [0.1, 0.5, 0.1, 0.9],
                 [0.2, 0.2, 0.6, 0.5],
                 [0.3, 0.1, 0.5, 0.9]].map { |values| Node.features(*values) }
  end

  it "examples get stored and retrieved" do
    @problem.set_examples([1,2,1,2], @features)
    labels, examples = @problem.examples
    labels.size.should == 4
    examples.size.should == 4
    examples.map { |row| row.size }.should == [4,4,4,4]
    first_row = examples.first
    first_row.map { |node| node.index }.should == [0,1,2,3]
    first_row.map { |node| node.value }.should == [0.2,0.3,0.4,0.4]
  end

  it "can be populated" do
    examples = [[0.2, 0.3, 0.4, 0.4],
                [0.1, 0.5, 0.1, 0.9],
                [0.2, 0.2, 0.6, 0.5],
                [0.3, 0.1, 0.5, 0.9]].map { |values| Node.features(*values) }
    lambda { @problem.set_examples([1,2,1,2], examples) }.should_not raise_error
  end

  it "can be set twice over" do
    # Fresh node arrays are built for each call, as two separate data sets.
    first_batch = [Node.features(0.2, 0.3, 0.4, 0.4), Node.features(0.3,0.1,0.5,0.9)]
    @problem.set_examples([1,2], first_batch)
    second_batch = [Node.features(0.2, 0.3, 0.4, 0.4), Node.features(0.3,0.1,0.5,0.9)]
    lambda { @problem.set_examples([8,2], second_batch) }.should_not raise_error
  end

end
@@ -0,0 +1,9 @@
1
# Shared spec bootstrap: set up Bundler, require the :default and :test
# groups, then load the gem under test.
require "bundler"
Bundler.setup
Bundler.require(:default, :test)
require "jrb-libsvm"

# Mix Libsvm into the top level so specs can name its constants unqualified.
include Libsvm

# No RSpec options are customised.
RSpec.configure { |_config| }
@@ -0,0 +1,47 @@
1
+ require 'spec_helper'
2
+
3
# End-to-end usage: build examples, train a model and check predictions.
describe "Basic usage" do
  before(:each) do
    @problem = Problem.new
    @parameter = Parameter.new
    @parameter.cache_size = 1 # mb

    # "eps is the stopping criterion (we usually use 0.00001 in nu-SVC,
    # 0.001 in others)." (from README)
    @parameter.eps = 0.001

    @parameter.c = 10
  end

  it "has a nice API" do
    sparse = { 11 => 0.11, 21 => 0.21, 101 => 0.99 }
    sparse.to_example.should == Node.features({ 11 => 0.11, 21 => 0.21, 101 => 0.99 })
  end

  it "is as in [PCI,217]" do
    examples = [[1, 0, 1], [-1, 0, -1]].map { |values| Node.features(values) }
    @problem.set_examples([1, -1], examples)

    model = Model.train(@problem, @parameter)

    model.predict(Node.features(1, 1, 1)).should == 1.0
    model.predict(Node.features(-1, 1, -1)).should == -1.0
    model.predict(Node.features(-1, 55, -1)).should == -1.0
  end

  it "kernel parameter use" do
    @parameter.kernel_type = Parameter::RBF
    examples = [[1, 2, 3], [-2, -2, -2]].map { |values| Node.features(values) }
    @problem.set_examples([1, 2], examples)

    model = Model.train(@problem, @parameter)

    model.predict(Node.features(1, 2, 3)).should == 2
  end
end
data/tmp/.gitkeep ADDED
File without changes