svm_toolkit 1.1.7-java → 1.1.8-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,224 +1,219 @@
1
- module SvmToolkit
2
-
3
- # Extends the Java SVM class
4
- #
5
- # Available methods include:
6
- #
7
- # Svm.svm_train(problem, param)
8
- #
9
- # problem:: instance of Problem
10
- # param:: instance of Parameter
11
- #
12
- # Returns an instance of Model
13
- #
14
- # Svm.svm_cross_validation(problem, param, nr_folds, target)
15
- #
16
- # problem:: instance of Problem
17
- # param:: instance of Parameter
18
- # nr_fold:: number of folds
19
- # target:: resulting predictions in an array
20
- #
21
- class Svm
22
-
23
- # Perform cross validation search on given gamma/cost values,
24
- # using an RBF kernel,
25
- # returning the best performing model and optionally displaying
26
- # a contour map of performance.
27
- #
28
- # training_set:: instance of Problem, used for training
29
- # cross_valn_set:: instance of Problem, used for evaluating models
30
- # costs:: array of cost values to search across
31
- # gammas:: array of gamma values to search across
32
- # params:: Optional parameters include:
33
- # * :evaluator => Evaluator::OverallAccuracy, the name of the class
34
- # to use for computing performance
35
- # * :show_plot => false, whether to display contour plot
36
- #
37
- # Returns an instance of Model, the best performing model.
38
- #
39
- def Svm.cross_validation_search(training_set, cross_valn_set,
40
- costs = [-2,-1,0,1,2,3].collect {|i| 2**i},
41
- gammas = [-2,-1,0,1,2,3].collect {|i| 2**i},
42
- params = {})
43
- evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
44
- show_plot = params.fetch :show_plot, false
45
-
46
- fjp = ForkJoinPool.new
47
- task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
48
- results, best_model = fjp.invoke task
49
-
50
- if show_plot
51
- ContourDisplay.new(costs.collect {|n| Math.log2(n)},
52
- gammas.collect {|n| Math.log2(n)},
53
- results)
54
- end
55
-
56
- return best_model
57
- end
58
-
59
- private
60
- # Set up the cross validation search across a cost/gamma pair
61
- class CrossValidationSearch < RecursiveTask
62
- # Creates an instance of the CrossValidationSearch.
63
- #
64
- # gammas:: array of gamma values to search over
65
- # costs:: array of cost values to search over
66
- # training_set:: for building the model
67
- # cross_valn_set:: for testing the model
68
- # evaluator:: name of Evaluator class, used for evaluating the model
69
- #
70
- def initialize gammas, costs, training_set, cross_valn_set, evaluator
71
- super()
72
-
73
- @gammas = gammas
74
- @costs = costs
75
- @training_set = training_set
76
- @cross_valn_set = cross_valn_set
77
- @evaluator = evaluator
78
- end
79
-
80
- # perform actual computation, return results/best_model
81
- def compute
82
- tasks = []
83
- # create one task per gamma/cost pair
84
- @gammas.each do |gamma|
85
- @costs.each do |cost|
86
- tasks << SvmTrainer.new(@training_set, Parameter.new(
87
- :svm_type => Parameter::C_SVC,
88
- :kernel_type => Parameter::RBF,
89
- :cost => cost,
90
- :gamma => gamma
91
- ), @cross_valn_set, @evaluator)
92
- end
93
- end
94
-
95
- # set off all the tasks
96
- tasks.each do |task|
97
- task.fork
98
- end
99
-
100
- # collect the results
101
- results = []
102
- best_model = nil
103
- lowest_error = nil
104
-
105
- @gammas.each do |gamma|
106
- results_row = []
107
- @costs.each do |cost|
108
- task = tasks.shift
109
- model, result = task.join
110
-
111
- if result.better_than? lowest_error
112
- best_model = model
113
- lowest_error = result
114
- end
115
- puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
116
- results_row << result.value
117
- end
118
- results << results_row
119
- end
120
-
121
- return results, best_model
122
- end
123
- end
124
-
125
- # Represent a single training task for an SVM RBF model
126
- class SvmTrainer < RecursiveTask
127
-
128
- # Creates an instance of an SvmTrainer.
129
- #
130
- # training_set:: used to train the model
131
- # parameters:: parameters for building the model
132
- # cross_valn_set:: used to test the model performance
133
- # evaluator:: class name of Evaluator to use for evaluating the model performance
134
- #
135
- def initialize training_set, parameters, cross_valn_set, evaluator
136
- super()
137
-
138
- @training_set = training_set
139
- @parameters = parameters
140
- @cross_valn_set = cross_valn_set
141
- @evaluator = evaluator
142
- end
143
-
144
- # Trains and evaluates a model, using the parameters.
145
- #
146
- # Returns the model and evaluation.
147
- #
148
- def compute
149
- model = Svm.svm_train @training_set, @parameters
150
- result = model.evaluate_dataset @cross_valn_set, :evaluator => @evaluator
151
- return model, result
152
- end
153
- end
154
-
155
- # Swing Frame displaying the cross-validation performance.
156
- #
157
- class ContourDisplay < javax.swing.JFrame
158
- # Creates an instance of the ContourDisplay.
159
- #
160
- # xs:: array of x-coordinates
161
- # ys:: array of y-coordinates
162
- # zs:: array of values for matching (x, y) coordinate
163
- #
164
- def initialize(xs, ys, zs)
165
- super("Cross-Validation Performance")
166
- self.setSize(500, 400)
167
-
168
- cxs = Java::double[][ys.size].new
169
- cys = Java::double[][ys.size].new
170
- ys.size.times do |i|
171
- cxs[i] = Java::double[xs.size].new
172
- cys[i] = Java::double[xs.size].new
173
- xs.size.times do |j|
174
- cxs[i][j] = xs[j]
175
- cys[i][j] = ys[i]
176
- end
177
- end
178
-
179
- czs = Java::double[][ys.size].new
180
- ys.size.times do |i|
181
- czs[i] = Java::double[xs.size].new
182
- xs.size.times do |j|
183
- czs[i][j] = zs[i][j]
184
- end
185
- end
186
-
187
- plot = ContourPlot.new(
188
- cxs,
189
- cys,
190
- czs,
191
- 10,
192
- false,
193
- "",
194
- "Cost (log-scale)",
195
- "Gamma (log-scale)",
196
- nil,
197
- nil
198
- )
199
- plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)
200
-
201
- symbol = DiamondSymbol.new
202
- symbol.border_color = java.awt::Color.blue
203
- symbol.fill_color = java.awt::Color.blue
204
- symbol.size = 4
205
-
206
- run = PlotRun.new
207
- ys.size.times do |i|
208
- xs.size.times do |j|
209
- run.add(PlotDatum.new(cxs[i][j], cys[i][j], false, symbol))
210
- end
211
- end
212
-
213
- plot.runs << run
214
-
215
- panel = PlotPanel.new(plot)
216
- panel.background = java.awt::Color.white
217
- add panel
218
-
219
- self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
220
- self.visible = true
221
- end
222
- end
223
- end
224
- end
1
+ module SvmToolkit
2
+
3
+ # Extends the Java SVM class.
4
+ class Svm
5
+
6
+ # Using given Problem and Parameter instances, builds and trains
7
+ # an SVM model.
8
+ def self.train(problem, parameters)
9
+ Svm.svm_train(problem, parameters)
10
+ end
11
+
12
+ # Using given Problem and Parameter instances, builds and trains
13
+ # SVM models using n-fold cross-validation.
14
+ def self.cross_validation(problem, parameters, n_folds, target)
15
+ Svm.svm_cross_validation(problem, parameters, n_folds, target)
16
+ end
17
+
18
+ # Perform cross validation search on given gamma/cost values,
19
+ # using an RBF kernel,
20
+ # returning the best performing model and optionally displaying
21
+ # a contour map of performance.
22
+ #
23
+ # * training_set - instance of Problem, used for training
24
+ # * cross_valn_set - instance of Problem, used for evaluating models
25
+ # * costs - array of cost values to search across
26
+ # * gammas - array of gamma values to search across
27
+ # * params - Optional parameters include:
28
+ # * :evaluator => Evaluator::OverallAccuracy, the name of the class
29
+ # to use for computing performance
30
+ # * :show_plot => false, whether to display contour plot
31
+ #
32
+ # Returns an instance of Model, the best performing model.
33
+ #
34
+ def Svm.cross_validation_search(training_set, cross_valn_set,
35
+ costs = [-2,-1,0,1,2,3].collect {|i| 2**i},
36
+ gammas = [-2,-1,0,1,2,3].collect {|i| 2**i},
37
+ params = {})
38
+ evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
39
+ show_plot = params.fetch :show_plot, false
40
+
41
+ fjp = ForkJoinPool.new
42
+ task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
43
+ results, best_model = fjp.invoke task
44
+
45
+ if show_plot
46
+ ContourDisplay.new(costs.collect {|n| Math.log2(n)},
47
+ gammas.collect {|n| Math.log2(n)},
48
+ results)
49
+ end
50
+
51
+ return best_model
52
+ end
53
+
54
+ private
55
+ # Set up the cross validation search across a cost/gamma pair
56
+ class CrossValidationSearch < RecursiveTask
57
+ # Creates an instance of the CrossValidationSearch.
58
+ #
59
+ # * gammas - array of gamma values to search over
60
+ # * costs - array of cost values to search over
61
+ # * training_set - for building the model
62
+ # * cross_valn_set - for testing the model
63
+ # * evaluator - name of Evaluator class, used for evaluating the model
64
+ #
65
+ def initialize gammas, costs, training_set, cross_valn_set, evaluator
66
+ super()
67
+
68
+ @gammas = gammas
69
+ @costs = costs
70
+ @training_set = training_set
71
+ @cross_valn_set = cross_valn_set
72
+ @evaluator = evaluator
73
+ end
74
+
75
+ # perform actual computation, return results/best_model
76
+ def compute
77
+ tasks = []
78
+ # create one task per gamma/cost pair
79
+ @gammas.each do |gamma|
80
+ @costs.each do |cost|
81
+ tasks << SvmTrainer.new(@training_set, Parameter.new(
82
+ :svm_type => Parameter::C_SVC,
83
+ :kernel_type => Parameter::RBF,
84
+ :cost => cost,
85
+ :gamma => gamma
86
+ ), @cross_valn_set, @evaluator)
87
+ end
88
+ end
89
+
90
+ # set off all the tasks
91
+ tasks.each do |task|
92
+ task.fork
93
+ end
94
+
95
+ # collect the results
96
+ results = []
97
+ best_model = nil
98
+ lowest_error = nil
99
+
100
+ @gammas.each do |gamma|
101
+ results_row = []
102
+ @costs.each do |cost|
103
+ task = tasks.shift
104
+ model, result = task.join
105
+
106
+ if result.better_than? lowest_error
107
+ best_model = model
108
+ lowest_error = result
109
+ end
110
+ puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
111
+ results_row << result.value
112
+ end
113
+ results << results_row
114
+ end
115
+
116
+ return results, best_model
117
+ end
118
+ end
119
+
120
+ # Represent a single training task for an SVM RBF model
121
+ class SvmTrainer < RecursiveTask
122
+
123
+ # Creates an instance of an SvmTrainer.
124
+ #
125
+ # * training_set - used to train the model
126
+ # * parameters - parameters for building the model
127
+ # * cross_valn_set - used to test the model performance
128
+ # * evaluator - class name of Evaluator to use for evaluating the model performance
129
+ #
130
+ def initialize training_set, parameters, cross_valn_set, evaluator
131
+ super()
132
+
133
+ @training_set = training_set
134
+ @parameters = parameters
135
+ @cross_valn_set = cross_valn_set
136
+ @evaluator = evaluator
137
+ end
138
+
139
+ # Trains and evaluates a model, using the parameters.
140
+ #
141
+ # Returns the model and evaluation.
142
+ #
143
+ def compute
144
+ model = Svm.svm_train @training_set, @parameters
145
+ result = model.evaluate_dataset @cross_valn_set, :evaluator => @evaluator
146
+ return model, result
147
+ end
148
+ end
149
+
150
+ # Swing Frame displaying the cross-validation performance.
151
+ #
152
+ class ContourDisplay < javax.swing.JFrame
153
+ # Creates an instance of the ContourDisplay.
154
+ #
155
+ # * xs - array of x-coordinates
156
+ # * ys - array of y-coordinates
157
+ # * zs - array of values for matching (x, y) coordinate
158
+ #
159
+ def initialize(xs, ys, zs)
160
+ super("Cross-Validation Performance")
161
+ self.setSize(500, 400)
162
+
163
+ cxs = Java::double[][ys.size].new
164
+ cys = Java::double[][ys.size].new
165
+ ys.size.times do |i|
166
+ cxs[i] = Java::double[xs.size].new
167
+ cys[i] = Java::double[xs.size].new
168
+ xs.size.times do |j|
169
+ cxs[i][j] = xs[j]
170
+ cys[i][j] = ys[i]
171
+ end
172
+ end
173
+
174
+ czs = Java::double[][ys.size].new
175
+ ys.size.times do |i|
176
+ czs[i] = Java::double[xs.size].new
177
+ xs.size.times do |j|
178
+ czs[i][j] = zs[i][j]
179
+ end
180
+ end
181
+
182
+ plot = ContourPlot.new(
183
+ cxs,
184
+ cys,
185
+ czs,
186
+ 10,
187
+ false,
188
+ "",
189
+ "Cost (log-scale)",
190
+ "Gamma (log-scale)",
191
+ nil,
192
+ nil
193
+ )
194
+ plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)
195
+
196
+ symbol = DiamondSymbol.new
197
+ symbol.border_color = java.awt::Color.blue
198
+ symbol.fill_color = java.awt::Color.blue
199
+ symbol.size = 4
200
+
201
+ run = PlotRun.new
202
+ ys.size.times do |i|
203
+ xs.size.times do |j|
204
+ run.add(PlotDatum.new(cxs[i][j], cys[i][j], false, symbol))
205
+ end
206
+ end
207
+
208
+ plot.runs << run
209
+
210
+ panel = PlotPanel.new(plot)
211
+ panel.background = java.awt::Color.white
212
+ add panel
213
+
214
+ self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
215
+ self.visible = true
216
+ end
217
+ end
218
+ end
219
+ end
data/lib/svm_toolkit.rb CHANGED
@@ -1,37 +1,37 @@
1
-
2
- require 'confusion_matrix'
3
- require 'csv'
4
-
5
- require 'java'
6
- require 'libsvm'
7
- require 'PlotPackage'
8
-
9
- # Containing module for this library.
10
- #
11
- module SvmToolkit
12
- # import the required java classes - must do this before loading the ruby classes
13
-
14
- java_import 'libsvm.Parameter'
15
- java_import 'libsvm.Model'
16
- java_import 'libsvm.Problem'
17
- java_import 'libsvm.Node'
18
- java_import 'libsvm.Svm'
19
-
20
- java_import 'java.util.concurrent.ForkJoinPool'
21
- java_import 'java.util.concurrent.RecursiveTask'
22
-
23
- java_import 'jahuwaldt.plot.ContourPlot'
24
- java_import 'jahuwaldt.plot.DiamondSymbol'
25
- java_import 'jahuwaldt.plot.PlotDatum'
26
- java_import 'jahuwaldt.plot.PlotPanel'
27
- java_import 'jahuwaldt.plot.PlotRun'
28
- end
29
-
30
- # finally require the ruby code which extends the Java classes
31
- require 'svm_toolkit/evaluators'
32
- require 'svm_toolkit/model'
33
- require 'svm_toolkit/node'
34
- require 'svm_toolkit/parameter'
35
- require 'svm_toolkit/problem'
36
- require 'svm_toolkit/svm'
37
-
1
+
2
+ require "confusion_matrix"
3
+ require "csv"
4
+
5
+ require "java"
6
+ require "libsvm"
7
+ require "PlotPackage"
8
+
9
+ # Containing module for this library.
10
+ #
11
+ module SvmToolkit
12
+ # import the required java classes - must do this before loading the ruby classes
13
+
14
+ java_import "libsvm.Parameter"
15
+ java_import "libsvm.Model"
16
+ java_import "libsvm.Problem"
17
+ java_import "libsvm.Node"
18
+ java_import "libsvm.Svm"
19
+
20
+ java_import "java.util.concurrent.ForkJoinPool"
21
+ java_import "java.util.concurrent.RecursiveTask"
22
+
23
+ java_import "jahuwaldt.plot.ContourPlot"
24
+ java_import "jahuwaldt.plot.DiamondSymbol"
25
+ java_import "jahuwaldt.plot.PlotDatum"
26
+ java_import "jahuwaldt.plot.PlotPanel"
27
+ java_import "jahuwaldt.plot.PlotRun"
28
+ end
29
+
30
+ # finally require the ruby code which extends the Java classes
31
+ require "svm_toolkit/evaluators"
32
+ require "svm_toolkit/model"
33
+ require "svm_toolkit/node"
34
+ require "svm_toolkit/parameter"
35
+ require "svm_toolkit/problem"
36
+ require "svm_toolkit/svm"
37
+
metadata CHANGED
@@ -1,29 +1,42 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svm_toolkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.7
4
+ version: 1.1.8
5
5
  platform: java
6
6
  authors:
7
7
  - Peter Lane
8
- autorequire:
9
8
  bindir: bin
10
9
  cert_chain: []
11
- date: 2023-02-05 00:00:00.000000000 Z
10
+ date: 1980-01-02 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
13
+ name: confusion_matrix
14
14
  requirement: !ruby/object:Gem::Requirement
15
15
  requirements:
16
16
  - - "~>"
17
17
  - !ruby/object:Gem::Version
18
18
  version: '1.0'
19
- name: confusion_matrix
20
- prerelease: false
21
19
  type: :runtime
20
+ prerelease: false
22
21
  version_requirements: !ruby/object:Gem::Requirement
23
22
  requirements:
24
23
  - - "~>"
25
24
  - !ruby/object:Gem::Version
26
25
  version: '1.0'
26
+ - !ruby/object:Gem::Dependency
27
+ name: csv
28
+ requirement: !ruby/object:Gem::Requirement
29
+ requirements:
30
+ - - "~>"
31
+ - !ruby/object:Gem::Version
32
+ version: '3.3'
33
+ type: :runtime
34
+ prerelease: false
35
+ version_requirements: !ruby/object:Gem::Requirement
36
+ requirements:
37
+ - - "~>"
38
+ - !ruby/object:Gem::Version
39
+ version: '3.3'
27
40
  description: "Support-vector machines are a popular tool in data mining. This package\
28
41
  \ includes an amended version of the Java implementation of the libsvm library (version\
29
42
  \ 3.11). Additional methods and examples are provided to support standard training\
@@ -34,8 +47,8 @@ executables:
34
47
  - svm-demo
35
48
  extensions: []
36
49
  extra_rdoc_files:
37
- - README.rdoc
38
50
  - LICENSE.rdoc
51
+ - README.rdoc
39
52
  files:
40
53
  - LICENSE.rdoc
41
54
  - README.rdoc
@@ -49,12 +62,14 @@ files:
49
62
  - lib/svm_toolkit/parameter.rb
50
63
  - lib/svm_toolkit/problem.rb
51
64
  - lib/svm_toolkit/svm.rb
52
- homepage:
65
+ homepage: https://rubygems.org/gems/svm_toolkit
53
66
  licenses:
54
67
  - MIT
55
- metadata: {}
68
+ metadata:
69
+ documentation_uri: https://rubydoc.info/gems/svm_toolkit
70
+ source_code_uri: https://codeberg.org/peterlane/svm_toolkit
56
71
  post_install_message: "'svm-demo' should now be available via your jruby path. If\
57
- \ not, add \nC:/jruby-9.4.0.0/lib/ruby/gems/shared/gems/svm_toolkit-1.1.7-java/bin\n\
72
+ \ not, add \n/home/peter/.rbenv/versions/jruby-10.0.2.0/lib/ruby/gems/shared/gems/svm_toolkit-1.1.8-java/bin\n\
58
73
  \ to your path."
59
74
  rdoc_options:
60
75
  - "-m"
@@ -65,15 +80,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
65
80
  requirements:
66
81
  - - ">="
67
82
  - !ruby/object:Gem::Version
68
- version: '0'
83
+ version: '3.4'
69
84
  required_rubygems_version: !ruby/object:Gem::Requirement
70
85
  requirements:
71
86
  - - ">="
72
87
  - !ruby/object:Gem::Version
73
88
  version: '0'
74
89
  requirements: []
75
- rubygems_version: 3.3.25
76
- signing_key:
90
+ rubygems_version: 3.6.9
77
91
  specification_version: 4
78
92
  summary: A JRuby wrapper around the libsvm library, with additional functionality.
79
93
  test_files: []