svm_toolkit 1.1.7-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,224 @@
1
+ module SvmToolkit
2
+
3
+ # Extends the Java SVM class
4
+ #
5
+ # Available methods include:
6
+ #
7
+ # Svm.svm_train(problem, param)
8
+ #
9
+ # problem:: instance of Problem
10
+ # param:: instance of Parameter
11
+ #
12
+ # Returns an instance of Model
13
+ #
14
+ # Svm.svm_cross_validation(problem, param, nr_folds, target)
15
+ #
16
+ # problem:: instance of Problem
17
+ # param:: instance of Parameter
18
+ # nr_folds:: number of folds
19
+ # target:: resulting predictions in an array
20
+ #
21
+ class Svm
22
+
23
# Perform a grid search over the given gamma/cost values, using an
# RBF kernel. One model is trained on the training set for every
# (gamma, cost) pair and evaluated against the cross-validation set;
# the individual training tasks are forked onto a Java ForkJoinPool.
# Returns the best performing model and optionally displays
# a contour map of performance.
#
# training_set:: instance of Problem, used for training
# cross_valn_set:: instance of Problem, used for evaluating models
# costs:: array of cost values to search across
# gammas:: array of gamma values to search across
# params:: Optional parameters include:
#          * :evaluator => Evaluator::OverallAccuracy, the name of the class
#            to use for computing performance
#          * :show_plot => false, whether to display contour plot
#
# Returns an instance of Model, the best performing model
# (nil when there are no gamma/cost pairs to try).
#
def Svm.cross_validation_search(training_set, cross_valn_set,
                                costs = [-2,-1,0,1,2,3].collect {|i| 2**i},
                                gammas = [-2,-1,0,1,2,3].collect {|i| 2**i},
                                params = {})
  evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
  show_plot = params.fetch :show_plot, false

  fjp = ForkJoinPool.new
  begin
    task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
    results, best_model = fjp.invoke task
  ensure
    # The pool is private to this call: release its worker threads as soon
    # as the search has finished (or failed) instead of leaving them idle.
    fjp.shutdown
  end

  if show_plot
    # results holds one row per gamma and one column per cost, so costs go
    # on the x-axis and gammas on the y-axis; both are shown on a log2 scale.
    ContourDisplay.new(costs.collect {|n| Math.log2(n)},
                       gammas.collect {|n| Math.log2(n)},
                       results)
  end

  best_model
end
58
+
59
+ private
60
# Fork/join task which runs the cross validation search over every
# gamma/cost pair, delegating each pair to its own SvmTrainer.
class CrossValidationSearch < RecursiveTask
  # Creates an instance of the CrossValidationSearch.
  #
  # gammas:: array of gamma values to search over
  # costs:: array of cost values to search over
  # training_set:: for building the model
  # cross_valn_set:: for testing the model
  # evaluator:: name of Evaluator class, used for evaluating the model
  #
  def initialize(gammas, costs, training_set, cross_valn_set, evaluator)
    super()

    @gammas = gammas
    @costs = costs
    @training_set = training_set
    @cross_valn_set = cross_valn_set
    @evaluator = evaluator
  end

  # Trains one model per gamma/cost pair and gathers the evaluations.
  #
  # Returns a pair: the grid of evaluation values (one row per gamma,
  # one column per cost) and the best performing model.
  #
  def compute
    # build the whole grid of trainers first: gammas outermost, costs innermost
    grid = @gammas.map do |gamma|
      @costs.map do |cost|
        settings = Parameter.new(
          :svm_type => Parameter::C_SVC,
          :kernel_type => Parameter::RBF,
          :cost => cost,
          :gamma => gamma
        )
        [cost, gamma, SvmTrainer.new(@training_set, settings, @cross_valn_set, @evaluator)]
      end
    end

    # start every trainer before waiting on any of them
    grid.each do |row|
      row.each { |_cost, _gamma, trainer| trainer.fork }
    end

    best_model = nil
    best_result = nil

    # join the trainers in the order they were created, keeping track of
    # the model whose evaluation beats everything seen so far
    results = grid.map do |row|
      row.map do |cost, gamma, trainer|
        model, result = trainer.join

        if result.better_than? best_result
          best_model = model
          best_result = result
        end
        puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
        result.value
      end
    end

    [results, best_model]
  end
end
124
+
125
# Fork/join task which trains and evaluates one SVM model
# for a single set of parameters.
class SvmTrainer < RecursiveTask

  # Creates an instance of an SvmTrainer.
  #
  # training_set:: used to train the model
  # parameters:: parameters for building the model
  # cross_valn_set:: used to test the model performance
  # evaluator:: class name of Evaluator to use for evaluating the model performance
  #
  def initialize(training_set, parameters, cross_valn_set, evaluator)
    super()

    @training_set = training_set
    @parameters = parameters
    @cross_valn_set = cross_valn_set
    @evaluator = evaluator
  end

  # Builds a model from the training set with the stored parameters,
  # then evaluates it on the cross-validation set.
  #
  # Returns the pair [model, evaluation].
  #
  def compute
    trained = Svm.svm_train(@training_set, @parameters)
    evaluation = trained.evaluate_dataset(@cross_valn_set, :evaluator => @evaluator)
    [trained, evaluation]
  end
end
154
+
155
# Swing frame which shows the cross-validation performance as a contour plot.
#
class ContourDisplay < javax.swing.JFrame
  # Creates an instance of the ContourDisplay, builds its plot and
  # makes the frame visible.
  #
  # xs:: array of x-coordinates
  # ys:: array of y-coordinates
  # zs:: array of rows of values, zs[i][j] being the value at (xs[j], ys[i])
  #
  def initialize(xs, ys, zs)
    super("Cross-Validation Performance")
    self.setSize(500, 400)

    row_count = ys.size
    col_count = xs.size

    # The plot takes three Java double[][] grids of identical shape:
    # one row per y value, one column per x value.
    cxs = Java::double[][row_count].new
    cys = Java::double[][row_count].new
    czs = Java::double[][row_count].new
    ys.each_with_index do |y, i|
      cxs[i] = Java::double[col_count].new
      cys[i] = Java::double[col_count].new
      czs[i] = Java::double[col_count].new
      xs.each_with_index do |x, j|
        cxs[i][j] = x
        cys[i][j] = y
        czs[i][j] = zs[i][j]
      end
    end

    # NOTE(review): the meaning of the positional arguments other than the
    # grids and the two axis labels is defined by jahuwaldt.plot.ContourPlot
    # and is not visible from this file.
    plot = ContourPlot.new(
      cxs,
      cys,
      czs,
      10,
      false,
      "",
      "Cost (log-scale)",
      "Gamma (log-scale)",
      nil,
      nil
    )
    plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)

    # every grid point is additionally drawn as a small blue diamond
    marker = DiamondSymbol.new
    marker.border_color = java.awt::Color.blue
    marker.fill_color = java.awt::Color.blue
    marker.size = 4

    points = PlotRun.new
    (0...row_count).each do |i|
      (0...col_count).each do |j|
        points.add(PlotDatum.new(cxs[i][j], cys[i][j], false, marker))
      end
    end

    plot.runs << points

    panel = PlotPanel.new(plot)
    panel.background = java.awt::Color.white
    add panel

    # closing the window disposes of this frame only
    self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
    self.visible = true
  end
end
223
+ end
224
+ end
@@ -0,0 +1,37 @@
1
+
2
+ require 'confusion_matrix'
3
+ require 'csv'
4
+
5
+ require 'java'
6
+ require 'libsvm'
7
+ require 'PlotPackage'
8
+
9
+ # Containing module for this library.
10
+ #
11
+ module SvmToolkit
12
+ # import the required java classes - must do this before loading the ruby classes
13
+
14
+ java_import 'libsvm.Parameter'
15
+ java_import 'libsvm.Model'
16
+ java_import 'libsvm.Problem'
17
+ java_import 'libsvm.Node'
18
+ java_import 'libsvm.Svm'
19
+
20
+ java_import 'java.util.concurrent.ForkJoinPool'
21
+ java_import 'java.util.concurrent.RecursiveTask'
22
+
23
+ java_import 'jahuwaldt.plot.ContourPlot'
24
+ java_import 'jahuwaldt.plot.DiamondSymbol'
25
+ java_import 'jahuwaldt.plot.PlotDatum'
26
+ java_import 'jahuwaldt.plot.PlotPanel'
27
+ java_import 'jahuwaldt.plot.PlotRun'
28
+ end
29
+
30
+ # finally require the ruby code which extends the Java classes
31
+ require 'svm_toolkit/evaluators'
32
+ require 'svm_toolkit/model'
33
+ require 'svm_toolkit/node'
34
+ require 'svm_toolkit/parameter'
35
+ require 'svm_toolkit/problem'
36
+ require 'svm_toolkit/svm'
37
+
metadata ADDED
@@ -0,0 +1,79 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: svm_toolkit
3
+ version: !ruby/object:Gem::Version
4
+ version: 1.1.7
5
+ platform: java
6
+ authors:
7
+ - Peter Lane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2023-02-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ requirement: !ruby/object:Gem::Requirement
15
+ requirements:
16
+ - - "~>"
17
+ - !ruby/object:Gem::Version
18
+ version: '1.0'
19
+ name: confusion_matrix
20
+ prerelease: false
21
+ type: :runtime
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.0'
27
+ description: "Support-vector machines are a popular tool in data mining. This package\
28
+ \ includes an amended version of the Java implementation of the libsvm library (version\
29
+ \ 3.11). Additional methods and examples are provided to support standard training\
30
+ \ techniques, such as cross-validation, various alternative evaluation methods,\
31
+ \ such as overall accuracy, precision or recall, and simple visualisations. \n"
32
+ email: peterlane@gmx.com
33
+ executables:
34
+ - svm-demo
35
+ extensions: []
36
+ extra_rdoc_files:
37
+ - README.rdoc
38
+ - LICENSE.rdoc
39
+ files:
40
+ - LICENSE.rdoc
41
+ - README.rdoc
42
+ - bin/svm-demo
43
+ - lib/PlotPackage.jar
44
+ - lib/libsvm.jar
45
+ - lib/svm_toolkit.rb
46
+ - lib/svm_toolkit/evaluators.rb
47
+ - lib/svm_toolkit/model.rb
48
+ - lib/svm_toolkit/node.rb
49
+ - lib/svm_toolkit/parameter.rb
50
+ - lib/svm_toolkit/problem.rb
51
+ - lib/svm_toolkit/svm.rb
52
+ homepage:
53
+ licenses:
54
+ - MIT
55
+ metadata: {}
56
+ post_install_message: "'svm-demo' should now be available via your jruby path. If\
57
+ \ not, add \nC:/jruby-9.4.0.0/lib/ruby/gems/shared/gems/svm_toolkit-1.1.7-java/bin\n\
58
+ \ to your path."
59
+ rdoc_options:
60
+ - "-m"
61
+ - README.rdoc
62
+ require_paths:
63
+ - lib
64
+ required_ruby_version: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ required_rubygems_version: !ruby/object:Gem::Requirement
70
+ requirements:
71
+ - - ">="
72
+ - !ruby/object:Gem::Version
73
+ version: '0'
74
+ requirements: []
75
+ rubygems_version: 3.3.25
76
+ signing_key:
77
+ specification_version: 4
78
+ summary: A JRuby wrapper around the libsvm library, with additional functionality.
79
+ test_files: []