svm_toolkit 1.1.7-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,124 @@
1
module SvmToolkit

  # Ruby-side conveniences layered on top of the Java Model class.
  #
  class Model

    # Run this model over every instance of the given data set (an instance
    # of Problem), handing each (true label, prediction) pair to an evaluator.
    #
    # Optional parameters:
    # * :evaluator => Evaluator::OverallAccuracy, the class used to accumulate performance
    # * :print_results => false, whether to print a line for each instance
    #
    # Returns the evaluator instance after all results have been added.
    def evaluate_dataset(data, params = {})
      evaluator_class = params.fetch(:evaluator, Evaluator::OverallAccuracy)
      verbose = params.fetch(:print_results, false)
      performance = evaluator_class.new
      data.l.times do |idx|
        predicted = Svm.svm_predict(self, data.x[idx])
        performance.add_result(data.y[idx], predicted)
        puts "Instance #{idx}, Prediction: #{predicted}, True label: #{data.y[idx]}" if verbose
      end
      performance
    end

    # Value of w squared for the hyperplane.
    # A single value is returned bare; several values are returned as an array.
    def w_squared
      return self.w_2[0] if self.w_2.size == 1

      self.w_2.to_a
    end

    # Indices of the training instances used as support vectors,
    # as a Ruby array (empty when the model holds no index information).
    def support_vector_indices
      return [] if sv_indices.nil?

      Array.new(sv_indices.size) { |position| sv_indices[position] }
    end

    # SVM problem type stored in this model's parameters.
    def svm_type
      param.svm_type
    end

    # Kernel type stored in this model's parameters.
    def kernel_type
      param.kernel_type
    end

    # Degree parameter stored in this model's parameters.
    def degree
      param.degree
    end

    # Gamma parameter stored in this model's parameters.
    def gamma
      param.gamma
    end

    # Cost parameter stored in this model's parameters.
    def cost
      param.cost
    end

    # Number of classes handled by this model.
    def number_classes
      nr_class
    end

    # Write the model to the given filename.
    # A Java IOException is reported as a Ruby IOError.
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException
      raise IOError, "Error in saving SVM model to file"
    end

    # Read a model from the given filename.
    # A Java IOException is reported as a Ruby IOError.
    def self.load(filename)
      Svm.svm_load_model(filename)
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from file"
    end

    #
    # Predict the class of the instance at position instance_number
    # within the given problem.
    #
    def predict(problem, instance_number)
      instance = problem.x[instance_number]
      Svm.svm_predict(self, instance)
    end

    #
    # Decision values of the instance at position instance_number within the
    # given problem, one per pairwise decision boundary.
    # (This is the distance of the instance from each boundary.)
    #
    # A single value is returned bare; several values are returned as an array.
    #
    def predict_values(problem, instance_number)
      # one boundary for every unordered pair of classes
      boundary_count = number_classes * (number_classes - 1) / 2
      dist = Array.new(boundary_count, 0).to_java(:double)
      Svm.svm_predict_values(self, problem.x[instance_number], dist)
      dist.size == 1 ? dist[0] : dist.to_a
    end
  end
end
124
+
@@ -0,0 +1,21 @@
1
module SvmToolkit

  # Ruby-side extension of the Java Node class.
  #
  # A Node holds the index/value pair describing one feature
  # of an instance.
  #
  class Node

    # Constructor:
    # index:: Position of this node within the feature set.
    # value:: Value of the feature at that position.
    def initialize(index, value)
      super()
      self.index, self.value = index, value
    end
  end

end
21
+
@@ -0,0 +1,117 @@
1
module SvmToolkit

  # Extends the Java Parameter class with some additional methods.
  #
  # Parameter holds values determining the kernel type
  # and training process.
  #
  # svm_type:: The type of SVM problem being solved.
  # * C_SVC, the usual classification task.
  # * NU_SVC
  # * ONE_CLASS
  # * EPSILON_SVR
  # * NU_SVR
  #
  # kernel_type:: The type of kernel to use.
  # * LINEAR
  # * POLY
  # * RBF
  # * SIGMOID
  # * PRECOMPUTED
  #
  # degree:: A parameter in polynomial kernels.
  #
  # gamma:: A parameter in poly/rbf/sigmoid kernels.
  #
  # coef0:: A parameter for poly/sigmoid kernels.
  #
  # cache_size:: For training, in MB.
  #
  # eps:: For training, stopping criterion.
  #
  # C:: For training with C_SVC, EPSILON_SVR, NU_SVR: the cost parameter.
  #
  # nr_weight:: For training with C_SVC.
  #
  # weight_label:: For training with C_SVC.
  #
  # weight:: For training with C_SVC.
  #
  # nu:: For training with NU_SVR, ONE_CLASS, NU_SVC.
  #
  # p:: For training with EPSILON_SVR.
  #
  # shrinking:: For training, whether to use shrinking heuristics.
  #
  # probability:: For training, whether to use probability estimates.
  #
  class Parameter

    # Constructor sets up values of attributes based on provided map.
    # The map may be omitted, in which case every attribute set here
    # takes its default.
    #
    # Valid keys with their default values:
    # * :svm_type = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost = 1.0, for the cost or C parameter
    # * :gamma = 0.0, for the gamma parameter in kernel
    # * :degree = 1, for polynomial kernel
    # * :coef0 = 0.0, for polynomial/sigmoid kernels
    # * :eps = 0.001, for stopping criterion
    # * :nr_weight = 0, for C_SVC
    # * :nu = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p = 0.1, used for EPSILON_SVR
    # * :shrinking = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    #
    # Raises ArgumentError if nu is not in (0,1].
    def initialize(args = {})
      super()
      self.svm_type = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C = args.fetch(:cost, 1.0)
      self.gamma = args.fetch(:gamma, 0.0)
      self.degree = args.fetch(:degree, 1)
      self.coef0 = args.fetch(:coef0, 0.0)
      self.eps = args.fetch(:eps, 0.001)
      self.nr_weight = args.fetch(:nr_weight, 0)
      self.nu = args.fetch(:nu, 0.5)
      self.p = args.fetch(:p, 0.1)
      self.shrinking = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)

      # Written as a positive range test so that NaN is rejected too.
      nu_in_range = self.nu > 0.0 && self.nu <= 1.0
      # The comma matters: `raise ArgumentError "..."` would call a method
      # named ArgumentError and fail with NoMethodError instead.
      raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]" unless nu_in_range
    end

    # A more readable accessor for the C parameter
    def cost
      self.C
    end

    # A more readable mutator for the C parameter
    def cost=(val)
      self.C = val
    end

    # Return a list of the available kernels.
    # (PRECOMPUTED is not included in this list.)
    def self.kernels
      [Parameter::LINEAR, Parameter::POLY, Parameter::RBF, Parameter::SIGMOID]
    end

    # Return a printable name for the given kernel,
    # or "Unknown" if the kernel is not one of those listed by Parameter.kernels.
    def self.kernel_name(kernel)
      case kernel
      when Parameter::LINEAR
        "Linear"
      when Parameter::POLY
        "Polynomial"
      when Parameter::RBF
        "Radial basis function"
      when Parameter::SIGMOID
        "Sigmoid"
      else
        "Unknown"
      end
    end
  end

end
117
+
@@ -0,0 +1,308 @@
1
module SvmToolkit

  # Extends the Java Problem class with some additional features.
  #
  class Problem

    # To select SvmLight input file format
    SvmLight = 0

    # To select Csv input file format
    Csv = 1

    # To select ARFF input file format
    Arff = 2

    # Support constructing a problem from arrays of double values.
    #
    # instances:: an array of instances, each instance being an array of doubles.
    # labels:: an array of doubles, forming the labels for each instance.
    #
    # An ArgumentError exception is raised if all the following conditions are not met:
    # * the number of instances should equal the number of labels,
    # * there must be at least one instance, and
    # * every instance must have the same number of features.
    #
    def self.from_array(instances, labels)
      unless instances.size == labels.size
        raise ArgumentError, "Number of instances must equal number of labels"
      end
      raise ArgumentError, "There must be at least one instance." if instances.empty?

      num_features = instances[0].size
      unless instances.all? { |instance| instance.size == num_features }
        raise ArgumentError, "All instances must have the same size"
      end

      rows = instances.map do |instance|
        instance.each_with_index.map { |value, j| Node.new(j, value) }
      end
      build_problem(rows, labels, num_features)
    end

    #
    # Read in a problem definition from a file.
    #
    # filename:: the name of the file
    # format:: either Problem::SvmLight (default), Problem::Csv or Problem::Arff
    #
    # Raises ArgumentError if the format is not one of the above.
    #
    def self.from_file(filename, format = SvmLight)
      case format
      when SvmLight
        from_file_svmlight(filename)
      when Csv
        from_file_csv(filename)
      when Arff
        from_file_arff(filename)
      else
        raise ArgumentError, "Unknown file format: #{format.inspect}"
      end
    end

    #
    # Read in a problem definition in svmlight format:
    # each line is a label followed by space-separated index:value pairs.
    # Blank lines are skipped. Features not mentioned on a line are 0.
    #
    # filename:: the name of the file
    #
    # The resulting instances have (largest index seen + 1) features,
    # so that position 0 is available.
    #
    def self.from_file_svmlight(filename)
      rows = []
      labels = []
      max_index = 0
      File.foreach(filename) do |line|
        tokens = line.split(" ")
        next if tokens.empty?

        labels << tokens[0].to_f
        row = tokens[1..-1].map do |feature|
          index, value = feature.split(":")
          Node.new(index.to_i, value.to_f)
        end
        max_index = [max_index, *row.map(&:index)].max
        rows << row
      end
      build_problem(rows, labels, max_index + 1) # +1 to allow for 0 position
    end

    #
    # Read in a problem definition in csv format:
    # the first field of each row is the label, the remaining fields
    # are the feature values. Blank rows are skipped.
    # Rows shorter than the widest row are padded with 0 values.
    #
    # filename:: the name of the file
    #
    def self.from_file_csv(filename)
      rows = []
      labels = []
      num_features = 0
      CSV.foreach(filename) do |tokens|
        next if tokens.empty?

        labels << tokens[0].to_f
        row = tokens[1..-1].each_with_index.map { |value, index| Node.new(index, value.to_f) }
        num_features = [row.size, num_features].max
        rows << row
      end
      build_problem(rows, labels, num_features)
    end

    #
    # Read in a problem definition in arff format.
    # Assumes all values are numbers (non-numbers converted to 0.0),
    # and that the class is the last field.
    # Everything up to and including the @data line is ignored, as are
    # blank lines and '%' comment lines in the data section.
    # Rows shorter than the widest row are padded with 0 values.
    #
    # filename:: the name of the file
    #
    def self.from_file_arff(filename)
      rows = []
      labels = []
      num_features = 0
      in_data = false
      File.foreach(filename) do |line|
        content = line.strip
        unless in_data
          in_data = content.downcase == "@data"
          next
        end
        next if content.empty? || content.start_with?("%")

        tokens = content.split(",")
        labels << tokens.last.to_f
        # every field except the last (the class) is a feature
        row = tokens[0...-1].each_with_index.map { |value, index| Node.new(index, value.to_f) }
        num_features = [row.size, num_features].max
        rows << row
      end
      build_problem(rows, labels, num_features)
    end

    # Assemble a Problem from parsed data.
    #
    # rows:: an array with, per instance, an array of the Nodes read for it.
    # labels:: an array of labels, one per instance.
    # num_features:: the number of Nodes each instance is given.
    #
    # Each Node is stored at the position given by its index; positions
    # left without a Node receive a Node with that position as index and value 0.
    #
    # Raises ArgumentError if the number of rows and labels differ.
    def self.build_problem(rows, labels, num_features)
      unless rows.size == labels.size
        raise ArgumentError, "Number of labels read differs from number of instances"
      end

      problem = Problem.new
      problem.l = rows.size
      # -- add in the training data
      problem.x = Node[rows.size, num_features].new
      rows.each_with_index do |row, i|
        row.each { |node| problem.x[i][node.index] = node }
        num_features.times { |j| problem.x[i][j] ||= Node.new(j, 0) }
      end
      # -- add in the labels
      problem.y = Java::double[labels.size].new
      labels.each_with_index { |label, i| problem.y[i] = label }

      problem
    end
    private_class_method :build_problem

    # Returns the number of instances
    def size
      self.l
    end

    # Rescale values within problem to be in range min_value to max_value.
    # A feature which has the same value in every instance is set to min_value.
    #
    # For SVM models, it is recommended all features be in range [0,1] or [-1,1]
    def rescale(min_value = 0.0, max_value = 1.0)
      return if self.l.zero?

      x[0].size.times do |col|
        rescale_column(col, min_value, max_value)
      end
    end

    # Create a new problem by combining the instances in this problem with
    # those in the given problem. The instances of this problem come first.
    # The new problem refers to the same Node objects as the two originals.
    #
    # Raises ArgumentError if the two problems differ in number of features.
    def merge(problem)
      unless self.x[0].size == problem.x[0].size
        raise ArgumentError, "Cannot merge two problems with different numbers of features"
      end

      num_features = self.x[0].size
      num_instances = size + problem.size

      new_problem = Problem.new
      new_problem.l = num_instances
      new_problem.x = Node[num_instances, num_features].new
      new_problem.y = Java::double[num_instances].new
      num_instances.times do |i|
        # rows below size come from this problem, the rest from the other
        source, row = i < size ? [self, i] : [problem, i - size]
        num_features.times { |j| new_problem.x[i][j] = source.x[row][j] }
        new_problem.y[i] = source.y[row]
      end

      new_problem
    end

    private

    # Rescale values within problem for given column index,
    # to be in range min_value to max_value.
    # If every value in the column is the same, each is set to min_value
    # (the usual formula would divide by zero).
    def rescale_column(col, min_value, max_value)
      nodes = Array.new(self.l) { |index| x[index][col] }
      # -- first locate the column's range
      current_min = nodes[0].value
      current_max = nodes[0].value
      nodes.each do |node|
        current_min = node.value if node.value < current_min
        current_max = node.value if node.value > current_max
      end
      range = current_max - current_min
      # -- then update each value
      nodes.each do |node|
        node.value =
          if range.zero?
            min_value
          else
            ((max_value - min_value) * (node.value - current_min) / range) + min_value
          end
      end
    end
  end
end
308
+