svm_toolkit 1.1.7-java

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,224 @@
1
+ module SvmToolkit
2
+
3
+ # Extends the Java SVM class
4
+ #
5
+ # Available methods include:
6
+ #
7
+ # Svm.svm_train(problem, param)
8
+ #
9
+ # problem:: instance of Problem
10
+ # param:: instance of Parameter
11
+ #
12
+ # Returns an instance of Model
13
+ #
14
+ # Svm.svm_cross_validation(problem, param, nr_folds, target)
15
+ #
16
+ # problem:: instance of Problem
17
+ # param:: instance of Parameter
18
+ # nr_folds:: number of folds
19
+ # target:: resulting predictions in an array
20
+ #
21
+ class Svm
22
+
23
# Perform a cross-validation grid search over the given gamma/cost values
# using an RBF kernel, returning the best performing model and optionally
# displaying a contour map of performance.
#
# Every (gamma, cost) pair is trained on +training_set+ and scored on
# +cross_valn_set+ as fork/join tasks on a ForkJoinPool created for this
# call. The pool is shut down once the search finishes (or fails), so its
# worker threads are not left behind after every search.
#
# training_set:: instance of Problem, used for training
# cross_valn_set:: instance of Problem, used for evaluating models
# costs:: array of cost values to search across (default 2**-2 .. 2**3)
# gammas:: array of gamma values to search across (default 2**-2 .. 2**3)
# params:: Optional parameters include:
#   * :evaluator => Evaluator::OverallAccuracy, the Evaluator class
#     used for computing performance
#   * :show_plot => false, whether to display a contour plot of the
#     results (axes are log2 of cost and log2 of gamma)
#
# Returns an instance of Model, the best performing model.
#
# Raises ArgumentError if +costs+ or +gammas+ is empty, because no model
# could be trained and there would be nothing to return.
def self.cross_validation_search(training_set, cross_valn_set,
                                 costs = [-2, -1, 0, 1, 2, 3].map { |i| 2**i },
                                 gammas = [-2, -1, 0, 1, 2, 3].map { |i| 2**i },
                                 params = {})
  if costs.empty? || gammas.empty?
    raise ArgumentError, 'costs and gammas must each contain at least one value'
  end

  evaluator = params.fetch(:evaluator, Evaluator::OverallAccuracy)
  show_plot = params.fetch(:show_plot, false)

  pool = ForkJoinPool.new
  begin
    task = CrossValidationSearch.new(gammas, costs, training_set, cross_valn_set, evaluator)
    results, best_model = pool.invoke(task)
  ensure
    # The pool is private to this call: shut it down even if training
    # raised, so repeated searches do not accumulate worker threads.
    pool.shutdown
  end

  if show_plot
    ContourDisplay.new(costs.map { |n| Math.log2(n) },
                       gammas.map { |n| Math.log2(n) },
                       results)
  end

  best_model
end
58
+
59
+ private
60
# Fork/join task that trains and scores one RBF model per gamma/cost pair
# and gathers the scores into a grid (one row per gamma).
class CrossValidationSearch < RecursiveTask
  # Stores the search grid and data sets; no training happens until
  # compute is invoked.
  #
  # gammas:: gamma values to search over (rows of the result grid)
  # costs:: cost values to search over (columns of the result grid)
  # training_set:: Problem used to build each model
  # cross_valn_set:: Problem used to score each model
  # evaluator:: Evaluator class used to score each model
  #
  def initialize(gammas, costs, training_set, cross_valn_set, evaluator)
    super()

    @evaluator = evaluator
    @cross_valn_set = cross_valn_set
    @training_set = training_set
    @costs = costs
    @gammas = gammas
  end

  # Forks one SvmTrainer per (gamma, cost) pair, then joins them in the
  # same gamma-major order in which they were created.
  #
  # Returns [results, best_model]: results has one row per gamma holding
  # result.value for each cost; best_model is the model whose result won
  # every better_than? comparison.
  def compute
    trainers = @gammas.flat_map do |gamma|
      @costs.map do |cost|
        rbf_parameters = Parameter.new(
          svm_type: Parameter::C_SVC,
          kernel_type: Parameter::RBF,
          cost: cost,
          gamma: gamma
        )
        SvmTrainer.new(@training_set, rbf_parameters, @cross_valn_set, @evaluator)
      end
    end

    # start every trainer before waiting on any of them
    trainers.each { |trainer| trainer.fork }

    best_model = nil
    best_result = nil # compared against via better_than?, starting from nil
    position = 0

    results = @gammas.map do |gamma|
      @costs.map do |cost|
        model, result = trainers[position].join
        position += 1

        if result.better_than?(best_result)
          best_model = model
          best_result = result
        end
        puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
        result.value
      end
    end

    [results, best_model]
  end
end
124
+
125
# Fork/join task that trains a single SVM with a fixed Parameter set and
# scores the trained model on a separate data set.
class SvmTrainer < RecursiveTask
  # training_set:: Problem the model is trained on
  # parameters:: Parameter instance configuring the model
  # cross_valn_set:: Problem the trained model is scored on
  # evaluator:: Evaluator class handed to Model#evaluate_dataset
  #
  def initialize(training_set, parameters, cross_valn_set, evaluator)
    super()

    @evaluator = evaluator
    @cross_valn_set = cross_valn_set
    @parameters = parameters
    @training_set = training_set
  end

  # Trains the model, then evaluates it.
  #
  # Returns a two-element array: [trained model, evaluation result].
  def compute
    trained = Svm.svm_train(@training_set, @parameters)
    [trained, trained.evaluate_dataset(@cross_valn_set, evaluator: @evaluator)]
  end
end
154
+
155
# Swing window showing cross-validation performance as a contour plot.
#
# The window builds its plot and makes itself visible in the constructor.
class ContourDisplay < javax.swing.JFrame
  # Builds the contour plot, marks every grid point and shows the window.
  #
  # xs:: x-coordinates, one per column of +zs+
  # ys:: y-coordinates, one per row of +zs+
  # zs:: grid of values; zs[i][j] is plotted at (xs[j], ys[i])
  #
  def initialize(xs, ys, zs)
    super("Cross-Validation Performance")
    self.setSize(500, 400)

    rows = ys.size
    cols = xs.size

    # Coordinates are passed to ContourPlot as full 2-D grids matching the
    # shape of czs: each x repeated down every row, each y across every column.
    cxs = Java::double[][rows].new
    cys = Java::double[][rows].new
    czs = Java::double[][rows].new
    rows.times do |row|
      cxs[row] = Java::double[cols].new
      cys[row] = Java::double[cols].new
      czs[row] = Java::double[cols].new
      cols.times do |col|
        cxs[row][col] = xs[col]
        cys[row][col] = ys[row]
        czs[row][col] = zs[row][col]
      end
    end

    # NOTE(review): 10 and false are positional ContourPlot constructor
    # options from PlotPackage (presumably contour-level count and a
    # log-scale flag) — confirm against that library.
    plot = ContourPlot.new(cxs, cys, czs, 10, false, "",
                           "Cost (log-scale)", "Gamma (log-scale)", nil, nil)
    plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)

    # small blue diamond drawn at every evaluated grid point
    marker = DiamondSymbol.new
    marker.border_color = java.awt::Color.blue
    marker.fill_color = java.awt::Color.blue
    marker.size = 4

    grid_points = PlotRun.new
    rows.times do |row|
      cols.times do |col|
        grid_points.add(PlotDatum.new(cxs[row][col], cys[row][col], false, marker))
      end
    end
    plot.runs << grid_points

    panel = PlotPanel.new(plot)
    panel.background = java.awt::Color.white
    add(panel)

    self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
    self.visible = true
  end
end
223
+ end
224
+ end
@@ -0,0 +1,37 @@
1
+
2
+ require 'confusion_matrix'
3
+ require 'csv'
4
+
5
+ require 'java'
6
+ require 'libsvm'
7
+ require 'PlotPackage'
8
+
9
# Containing module for this library.
#
# The Java classes are imported into this namespace here, before the Ruby
# files required afterwards reopen and extend them.
module SvmToolkit
  # libsvm core classes
  %w[Parameter Model Problem Node Svm].each do |name|
    java_import "libsvm.#{name}"
  end

  # fork/join classes used by the parallel cross-validation search
  %w[ForkJoinPool RecursiveTask].each do |name|
    java_import "java.util.concurrent.#{name}"
  end

  # plotting classes used for the contour display
  %w[ContourPlot DiamondSymbol PlotDatum PlotPanel PlotRun].each do |name|
    java_import "jahuwaldt.plot.#{name}"
  end
end
29
+
30
+ # finally require the ruby code which extends the Java classes
31
+ require 'svm_toolkit/evaluators'
32
+ require 'svm_toolkit/model'
33
+ require 'svm_toolkit/node'
34
+ require 'svm_toolkit/parameter'
35
+ require 'svm_toolkit/problem'
36
+ require 'svm_toolkit/svm'
37
+
metadata ADDED
@@ -0,0 +1,79 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: svm_toolkit
3
+ version: !ruby/object:Gem::Version
4
+ version: 1.1.7
5
+ platform: java
6
+ authors:
7
+ - Peter Lane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2023-02-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ requirement: !ruby/object:Gem::Requirement
15
+ requirements:
16
+ - - "~>"
17
+ - !ruby/object:Gem::Version
18
+ version: '1.0'
19
+ name: confusion_matrix
20
+ prerelease: false
21
+ type: :runtime
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.0'
27
+ description: "Support-vector machines are a popular tool in data mining. This package\
28
+ \ includes an amended version of the Java implementation of the libsvm library (version\
29
+ \ 3.11). Additional methods and examples are provided to support standard training\
30
+ \ techniques, such as cross-validation, various alternative evaluation methods,\
31
+ \ such as overall accuracy, precision or recall, and simple visualisations. \n"
32
+ email: peterlane@gmx.com
33
+ executables:
34
+ - svm-demo
35
+ extensions: []
36
+ extra_rdoc_files:
37
+ - README.rdoc
38
+ - LICENSE.rdoc
39
+ files:
40
+ - LICENSE.rdoc
41
+ - README.rdoc
42
+ - bin/svm-demo
43
+ - lib/PlotPackage.jar
44
+ - lib/libsvm.jar
45
+ - lib/svm_toolkit.rb
46
+ - lib/svm_toolkit/evaluators.rb
47
+ - lib/svm_toolkit/model.rb
48
+ - lib/svm_toolkit/node.rb
49
+ - lib/svm_toolkit/parameter.rb
50
+ - lib/svm_toolkit/problem.rb
51
+ - lib/svm_toolkit/svm.rb
52
+ homepage:
53
+ licenses:
54
+ - MIT
55
+ metadata: {}
56
+ post_install_message: "'svm-demo' should now be available via your jruby path. If\
57
+ \ not, add \nC:/jruby-9.4.0.0/lib/ruby/gems/shared/gems/svm_toolkit-1.1.7-java/bin\n\
58
+ \ to your path."
59
+ rdoc_options:
60
+ - "-m"
61
+ - README.rdoc
62
+ require_paths:
63
+ - lib
64
+ required_ruby_version: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ required_rubygems_version: !ruby/object:Gem::Requirement
70
+ requirements:
71
+ - - ">="
72
+ - !ruby/object:Gem::Version
73
+ version: '0'
74
+ requirements: []
75
+ rubygems_version: 3.3.25
76
+ signing_key:
77
+ specification_version: 4
78
+ summary: A JRuby wrapper around the libsvm library, with additional functionality.
79
+ test_files: []