svm_toolkit 1.1.7-java
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.rdoc +59 -0
- data/README.rdoc +103 -0
- data/bin/svm-demo +354 -0
- data/lib/PlotPackage.jar +0 -0
- data/lib/libsvm.jar +0 -0
- data/lib/svm_toolkit/evaluators.rb +169 -0
- data/lib/svm_toolkit/model.rb +124 -0
- data/lib/svm_toolkit/node.rb +21 -0
- data/lib/svm_toolkit/parameter.rb +117 -0
- data/lib/svm_toolkit/problem.rb +308 -0
- data/lib/svm_toolkit/svm.rb +224 -0
- data/lib/svm_toolkit.rb +37 -0
- metadata +79 -0
@@ -0,0 +1,224 @@
|
|
1
|
+
module SvmToolkit

  # Extends the Java SVM class.
  #
  # Available methods include:
  #
  # Svm.svm_train(problem, param)
  #
  # problem:: instance of Problem
  # param:: instance of Parameter
  #
  # Returns an instance of Model
  #
  # Svm.svm_cross_validation(problem, param, nr_fold, target)
  #
  # problem:: instance of Problem
  # param:: instance of Parameter
  # nr_fold:: number of folds
  # target:: resulting predictions in an array
  #
  class Svm

    # Perform cross validation search on given gamma/cost values,
    # using an RBF kernel with a C_SVC machine, returning the best performing
    # model and optionally displaying a contour map of performance.
    #
    # Every gamma/cost pair is trained on +training_set+ and evaluated on
    # +cross_valn_set+. The pairs run as fork/join subtasks on a pool created
    # for this call, which is shut down before the method returns. The score of
    # each pair is printed to standard output as it is collected.
    #
    # training_set:: instance of Problem, used for training
    # cross_valn_set:: instance of Problem, used for evaluating models
    # costs:: array of cost values to search across
    #         (default: 2**-2 .. 2**3, as Floats)
    # gammas:: array of gamma values to search across
    #          (default: 2**-2 .. 2**3, as Floats)
    # params:: Optional parameters include:
    #          * :evaluator => Evaluator::OverallAccuracy, the name of the class
    #            to use for computing performance
    #          * :show_plot => false, whether to display contour plot
    #
    # Returns an instance of Model, the best performing model
    # (nil when +costs+ or +gammas+ is empty).
    #
    def self.cross_validation_search(training_set, cross_valn_set,
                                     costs = [-2, -1, 0, 1, 2, 3].map { |i| 2.0**i },
                                     gammas = [-2, -1, 0, 1, 2, 3].map { |i| 2.0**i },
                                     params = {})
      # NOTE: the defaults use 2.0**i. Integer#** would return Rationals (1/4, 1/2)
      # for the negative exponents. Float values are numerically equal, convert
      # directly to Java doubles, and print as decimals.
      evaluator = params.fetch(:evaluator, Evaluator::OverallAccuracy)
      show_plot = params.fetch(:show_plot, false)

      pool = ForkJoinPool.new
      begin
        task = CrossValidationSearch.new(gammas, costs, training_set, cross_valn_set, evaluator)
        results, best_model = pool.invoke(task)
      ensure
        # The pool is private to this call: release its worker threads once the
        # search has finished (or failed) instead of leaving them to idle out.
        pool.shutdown
      end

      if show_plot
        ContourDisplay.new(costs.map { |n| Math.log2(n) },
                           gammas.map { |n| Math.log2(n) },
                           results)
      end

      best_model
    end

    # The classes below are implementation helpers for cross_validation_search.
    # A bare `private` does not hide nested class constants, so none is used here.

    # Fork/join task performing the cross validation search across every
    # gamma/cost pair: one SvmTrainer subtask is forked per pair.
    class CrossValidationSearch < RecursiveTask
      # Creates an instance of the CrossValidationSearch.
      #
      # gammas:: array of gamma values to search over
      # costs:: array of cost values to search over
      # training_set:: for building the model
      # cross_valn_set:: for testing the model
      # evaluator:: name of Evaluator class, used for evaluating the model
      #
      def initialize(gammas, costs, training_set, cross_valn_set, evaluator)
        super()

        @gammas = gammas
        @costs = costs
        @training_set = training_set
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end

      # Trains and evaluates one model per gamma/cost pair.
      #
      # Returns [results, best_model]:
      # results:: array with one row per gamma and one entry per cost, holding
      #           the +value+ of each evaluation
      # best_model:: the model whose evaluation won every +better_than?+
      #              comparison (nil if no pairs were searched)
      #
      def compute
        # One trainer per pair, created in the same gamma-major order in which
        # the results are collected below.
        trainers = []
        @gammas.each do |gamma|
          @costs.each do |cost|
            parameter = Parameter.new(
              :svm_type => Parameter::C_SVC,
              :kernel_type => Parameter::RBF,
              :cost => cost,
              :gamma => gamma
            )
            trainers << SvmTrainer.new(@training_set, parameter, @cross_valn_set, @evaluator)
          end
        end

        # Fork every trainer before joining any, so the pool can run them concurrently.
        trainers.each { |trainer| trainer.fork }

        best_model = nil
        best_result = nil # nil until the first evaluation has been seen

        results = @gammas.map do |gamma|
          @costs.map do |cost|
            model, result = trainers.shift.join

            # NOTE(review): better_than? receives nil on the first comparison;
            # presumably the evaluator treats nil as "no result yet" — confirm.
            if result.better_than?(best_result)
              best_model = model
              best_result = result
            end
            puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
            result.value
          end
        end

        [results, best_model]
      end
    end

    # Represent a single training task for an SVM RBF model.
    class SvmTrainer < RecursiveTask

      # Creates an instance of an SvmTrainer.
      #
      # training_set:: used to train the model
      # parameters:: parameters for building the model
      # cross_valn_set:: used to test the model performance
      # evaluator:: class name of Evaluator to use for evaluating the model performance
      #
      def initialize(training_set, parameters, cross_valn_set, evaluator)
        super()

        @training_set = training_set
        @parameters = parameters
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end

      # Trains a model on the training set using the parameters, then evaluates
      # it on the cross-validation set.
      #
      # Returns [model, evaluation].
      #
      def compute
        model = Svm.svm_train(@training_set, @parameters)
        result = model.evaluate_dataset(@cross_valn_set, :evaluator => @evaluator)
        [model, result]
      end
    end

    # Swing Frame displaying the cross-validation performance as a contour plot,
    # with a marker at every evaluated grid point.
    #
    class ContourDisplay < javax.swing.JFrame
      # Creates and shows the ContourDisplay.
      #
      # xs:: array of x-coordinates
      # ys:: array of y-coordinates
      # zs:: array of rows, one per y-coordinate, each holding one value per
      #      x-coordinate, so zs[i][j] is the value at (xs[j], ys[i])
      #
      def initialize(xs, ys, zs)
        super("Cross-Validation Performance")
        self.setSize(500, 400)

        # Build three parallel 2-D Java double arrays, indexed [row][column], with
        # rows following ys and columns following xs.
        cxs = Java::double[][ys.size].new
        cys = Java::double[][ys.size].new
        czs = Java::double[][ys.size].new
        ys.size.times do |i|
          cxs[i] = Java::double[xs.size].new
          cys[i] = Java::double[xs.size].new
          czs[i] = Java::double[xs.size].new
          xs.size.times do |j|
            cxs[i][j] = xs[j]
            cys[i][j] = ys[i]
            czs[i][j] = zs[i][j]
          end
        end

        plot = ContourPlot.new(
          cxs,
          cys,
          czs,
          10,
          false,
          "",
          "Cost (log-scale)",
          "Gamma (log-scale)",
          nil,
          nil
        )
        plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)

        # Small blue diamond used to mark each evaluated grid point.
        symbol = DiamondSymbol.new
        symbol.border_color = java.awt::Color.blue
        symbol.fill_color = java.awt::Color.blue
        symbol.size = 4

        run = PlotRun.new
        ys.size.times do |i|
          xs.size.times do |j|
            run.add(PlotDatum.new(cxs[i][j], cys[i][j], false, symbol))
          end
        end

        plot.runs << run

        panel = PlotPanel.new(plot)
        panel.background = java.awt::Color.white
        add(panel)

        self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
        self.visible = true
      end
    end
  end
end
|
data/lib/svm_toolkit.rb
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
|
2
|
+
require 'confusion_matrix'
|
3
|
+
require 'csv'
|
4
|
+
|
5
|
+
require 'java'
|
6
|
+
require 'libsvm'
|
7
|
+
require 'PlotPackage'
|
8
|
+
|
9
|
+
# Namespace for the whole library. The libsvm, fork/join and plotting Java
# classes are imported into it so the Ruby code can name them unqualified.
#
module SvmToolkit
  # The Java classes must be imported before the Ruby files that extend them
  # are required.

  # libsvm classes, extended by the Ruby files in svm_toolkit/.
  %w[Parameter Model Problem Node Svm].each do |klass|
    java_import "libsvm.#{klass}"
  end

  # Fork/join support for the parallel cross-validation search.
  %w[ForkJoinPool RecursiveTask].each do |klass|
    java_import "java.util.concurrent.#{klass}"
  end

  # Plotting classes for the contour display.
  %w[ContourPlot DiamondSymbol PlotDatum PlotPanel PlotRun].each do |klass|
    java_import "jahuwaldt.plot.#{klass}"
  end
end
|
29
|
+
|
30
|
+
# finally require the ruby code which extends the Java classes
|
31
|
+
require 'svm_toolkit/evaluators'
|
32
|
+
require 'svm_toolkit/model'
|
33
|
+
require 'svm_toolkit/node'
|
34
|
+
require 'svm_toolkit/parameter'
|
35
|
+
require 'svm_toolkit/problem'
|
36
|
+
require 'svm_toolkit/svm'
|
37
|
+
|
metadata
ADDED
@@ -0,0 +1,79 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: svm_toolkit
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 1.1.7
|
5
|
+
platform: java
|
6
|
+
authors:
|
7
|
+
- Peter Lane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2023-02-05 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
15
|
+
requirements:
|
16
|
+
- - "~>"
|
17
|
+
- !ruby/object:Gem::Version
|
18
|
+
version: '1.0'
|
19
|
+
name: confusion_matrix
|
20
|
+
prerelease: false
|
21
|
+
type: :runtime
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '1.0'
|
27
|
+
description: "Support-vector machines are a popular tool in data mining. This package\
|
28
|
+
\ includes an amended version of the Java implementation of the libsvm library (version\
|
29
|
+
\ 3.11). Additional methods and examples are provided to support standard training\
|
30
|
+
\ techniques, such as cross-validation, various alternative evaluation methods,\
|
31
|
+
\ such as overall accuracy, precision or recall, and simple visualisations. \n"
|
32
|
+
email: peterlane@gmx.com
|
33
|
+
executables:
|
34
|
+
- svm-demo
|
35
|
+
extensions: []
|
36
|
+
extra_rdoc_files:
|
37
|
+
- README.rdoc
|
38
|
+
- LICENSE.rdoc
|
39
|
+
files:
|
40
|
+
- LICENSE.rdoc
|
41
|
+
- README.rdoc
|
42
|
+
- bin/svm-demo
|
43
|
+
- lib/PlotPackage.jar
|
44
|
+
- lib/libsvm.jar
|
45
|
+
- lib/svm_toolkit.rb
|
46
|
+
- lib/svm_toolkit/evaluators.rb
|
47
|
+
- lib/svm_toolkit/model.rb
|
48
|
+
- lib/svm_toolkit/node.rb
|
49
|
+
- lib/svm_toolkit/parameter.rb
|
50
|
+
- lib/svm_toolkit/problem.rb
|
51
|
+
- lib/svm_toolkit/svm.rb
|
52
|
+
homepage:
|
53
|
+
licenses:
|
54
|
+
- MIT
|
55
|
+
metadata: {}
|
56
|
+
post_install_message: "'svm-demo' should now be available via your jruby path. If\
|
57
|
+
\ not, add \nC:/jruby-9.4.0.0/lib/ruby/gems/shared/gems/svm_toolkit-1.1.7-java/bin\n\
|
58
|
+
\ to your path."
|
59
|
+
rdoc_options:
|
60
|
+
- "-m"
|
61
|
+
- README.rdoc
|
62
|
+
require_paths:
|
63
|
+
- lib
|
64
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '0'
|
69
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
70
|
+
requirements:
|
71
|
+
- - ">="
|
72
|
+
- !ruby/object:Gem::Version
|
73
|
+
version: '0'
|
74
|
+
requirements: []
|
75
|
+
rubygems_version: 3.3.25
|
76
|
+
signing_key:
|
77
|
+
specification_version: 4
|
78
|
+
summary: A JRuby wrapper around the libsvm library, with additional functionality.
|
79
|
+
test_files: []
|