svm_toolkit 1.1.7-java

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,124 @@
1
module SvmToolkit

  # Extends the Java Model class with some additional methods.
  #
  class Model

    # Evaluate model on given data set (an instance of Problem).
    #
    # Every instance in +data+ is classified with Svm.svm_predict and the
    # (true label, prediction) pair is passed to the evaluator's +add_result+.
    #
    # data:: the data set; its +l+, +x+ and +y+ fields are read.
    # params:: optional settings:
    #          * :evaluator => Evaluator::OverallAccuracy, the class to
    #            instantiate for computing performance
    #          * :print_results => false, whether to print the result for
    #            each instance
    #
    # Returns the evaluator instance that collected the results.
    def evaluate_dataset(data, params = {})
      evaluator = params.fetch(:evaluator, Evaluator::OverallAccuracy)
      print_results = params.fetch(:print_results, false)
      performance = evaluator.new
      data.l.times do |i|
        pred = Svm.svm_predict(self, data.x[i])
        performance.add_result(data.y[i], pred)
        puts "Instance #{i}, Prediction: #{pred}, True label: #{data.y[i]}" if print_results
      end
      performance
    end

    # Return the value of w squared for the hyperplane.
    # -- returned as an array if there is not just one value.
    def w_squared
      if self.w_2.size == 1
        self.w_2[0]
      else
        self.w_2.to_a
      end
    end

    # Return an array of indices of the training instances used as
    # support vectors. Returns an empty array when +sv_indices+ is nil.
    def support_vector_indices
      return [] if sv_indices.nil?

      sv_indices.to_a
    end

    # Return the SVM problem type for this model
    def svm_type
      self.param.svm_type
    end

    # Return the kernel type for this model
    def kernel_type
      self.param.kernel_type
    end

    # Return the value of the degree parameter
    def degree
      self.param.degree
    end

    # Return the value of the gamma parameter
    def gamma
      self.param.gamma
    end

    # Return the value of the cost parameter
    def cost
      self.param.cost
    end

    # Return the number of classes handled by this model.
    def number_classes
      self.nr_class
    end

    # Save model to given filename.
    # Raises IOError if the underlying call raises java.io.IOException.
    def save(filename)
      Svm.svm_save_model(filename, self)
    rescue java.io.IOException
      raise IOError, "Error in saving SVM model to file"
    end

    # Load model from given filename.
    # Raises IOError if the underlying call raises java.io.IOException.
    def self.load(filename)
      Svm.svm_load_model(filename)
    rescue java.io.IOException
      raise IOError, "Error in loading SVM model from file"
    end

    #
    # Predict the class of given instance number in given problem.
    #
    def predict(problem, instance_number)
      Svm.svm_predict(self, problem.x[instance_number])
    end

    #
    # Return the values of given instance number of given problem against
    # each decision boundary.
    # (This is the distance of the instance from each boundary.)
    #
    # Return value is an array if more than one decision boundary.
    #
    def predict_values(problem, instance_number)
      # one slot per pair of classes: k * (k - 1) / 2
      dist = Array.new(number_classes * (number_classes - 1) / 2, 0).to_java(:double)
      Svm.svm_predict_values(self, problem.x[instance_number], dist)
      dist.size == 1 ? dist[0] : dist.to_a
    end
  end
end
124
+
@@ -0,0 +1,21 @@
1
module SvmToolkit

  # Extends the Java Node class.
  #
  # Node is used to store the index/value pair for an individual
  # feature of an instance.
  #
  class Node

    # Constructor:
    # index:: Index of this node in feature set.
    # value:: Value of this node in feature set.
    def initialize(index, value)
      super() # the parent constructor is called with no arguments
      self.index = index
      self.value = value
    end
  end

end
21
+
@@ -0,0 +1,117 @@
1
module SvmToolkit

  # Extends the Java Parameter class with some additional methods.
  #
  # Parameter holds values determining the kernel type
  # and training process.
  #
  # svm_type:: The type of SVM problem being solved.
  # * C_SVC, the usual classification task.
  # * NU_SVC
  # * ONE_CLASS
  # * EPSILON_SVR
  # * NU_SVR
  #
  # kernel_type:: The type of kernel to use.
  # * LINEAR
  # * POLY
  # * RBF
  # * SIGMOID
  # * PRECOMPUTED
  #
  # degree:: A parameter in polynomial kernels.
  #
  # gamma:: A parameter in poly/rbf/sigmoid kernels.
  #
  # coef0:: A parameter for poly/sigmoid kernels.
  #
  # cache_size:: For training, in MB.
  #
  # eps:: For training, stopping criterion.
  #
  # C:: For training with C_SVC, EPSILON_SVR, NU_SVR: the cost parameter.
  #
  # nr_weight:: For training with C_SVC.
  #
  # weight_label:: For training with C_SVC.
  #
  # weight:: For training with C_SVC.
  #
  # nu:: For training with NU_SVR, ONE_CLASS, NU_SVC.
  #
  # p:: For training with EPSILON_SVR.
  #
  # shrinking:: training, whether to use shrinking heuristics.
  #
  # probability:: For training, whether to use probability estimates.
  #
  class Parameter

    # Constructor sets up values of attributes based on provided map.
    # The map may be omitted, in which case every default applies.
    # Valid keys with their default values:
    # * :svm_type = Parameter::C_SVC, for the type of SVM
    # * :kernel_type = Parameter::LINEAR, for the type of kernel
    # * :cost = 1.0, for the cost or C parameter
    # * :gamma = 0.0, for the gamma parameter in kernel
    # * :degree = 1, for polynomial kernel
    # * :coef0 = 0.0, for polynomial/sigmoid kernels
    # * :eps = 0.001, for stopping criterion
    # * :nr_weight = 0, for C_SVC
    # * :nu = 0.5, used for NU_SVC, ONE_CLASS and NU_SVR. Nu must be in (0,1]
    # * :p = 0.1, used for EPSILON_SVR
    # * :shrinking = 1, use the shrinking heuristics
    # * :probability = 0, use the probability estimates
    #
    # Raises ArgumentError if nu is not in (0,1].
    def initialize(args = {})
      super()
      self.svm_type = args.fetch(:svm_type, Parameter::C_SVC)
      self.kernel_type = args.fetch(:kernel_type, Parameter::LINEAR)
      self.C = args.fetch(:cost, 1.0)
      self.gamma = args.fetch(:gamma, 0.0)
      self.degree = args.fetch(:degree, 1)
      self.coef0 = args.fetch(:coef0, 0.0)
      self.eps = args.fetch(:eps, 0.001)
      self.nr_weight = args.fetch(:nr_weight, 0)
      self.nu = args.fetch(:nu, 0.5)
      self.p = args.fetch(:p, 0.1)
      self.shrinking = args.fetch(:shrinking, 1)
      self.probability = args.fetch(:probability, 0)

      # nu is validated whatever the svm_type, as it always has a value here
      unless self.nu > 0.0 && self.nu <= 1.0
        raise ArgumentError, "Invalid value of nu #{self.nu}, should be in (0,1]"
      end
    end

    # A more readable accessor for the C parameter
    def cost
      self.C
    end

    # A more readable mutator for the C parameter
    def cost=(val)
      self.C = val
    end

    # Return a list of the available kernels.
    def self.kernels
      [Parameter::LINEAR, Parameter::POLY, Parameter::RBF, Parameter::SIGMOID]
    end

    # Return a printable name for the given kernel,
    # or "Unknown" if the kernel is not one of those in Parameter.kernels.
    def self.kernel_name(kernel)
      case kernel
      when Parameter::LINEAR
        "Linear"
      when Parameter::POLY
        "Polynomial"
      when Parameter::RBF
        "Radial basis function"
      when Parameter::SIGMOID
        "Sigmoid"
      else
        "Unknown"
      end
    end
  end

end
117
+
@@ -0,0 +1,308 @@
1
module SvmToolkit

  # Extends the Java Problem class with some additional features.
  #
  class Problem

    # To select SvmLight input file format
    SvmLight = 0

    # To select Csv input file format
    Csv = 1

    # To select ARFF input file format
    Arff = 2

    # Support constructing a problem from arrays of double values.
    #
    # instances:: an array of instances, each instance being an array of doubles.
    # labels:: an array of doubles, forming the labels for each instance.
    #
    # An ArgumentError exception is raised if all the following conditions are not met:
    # * the number of instances should equal the number of labels,
    # * there must be at least one instance, and
    # * every instance must have the same number of features.
    #
    def self.from_array(instances, labels)
      unless instances.size == labels.size
        raise ArgumentError, "Number of instances must equal number of labels"
      end
      if instances.empty?
        raise ArgumentError, "There must be at least one instance."
      end
      unless instances.map(&:size).uniq.size == 1
        raise ArgumentError, "All instances must have the same size"
      end

      rows = instances.map do |instance|
        instance.each_with_index.map { |value, index| Node.new(index, value) }
      end
      build_problem(rows, labels, instances.first.size)
    end

    #
    # Read in a problem definition from a file.
    #
    # filename:: the name of the file
    # format:: either Problem::SvmLight (default), Problem::Csv or Problem::Arff
    #
    # Raises ArgumentError if the format is not one of the above.
    #
    def self.from_file(filename, format = SvmLight)
      case format
      when SvmLight
        from_file_svmlight(filename)
      when Csv
        from_file_csv(filename)
      when Arff
        from_file_arff(filename)
      else
        raise ArgumentError, "Unknown file format: #{format.inspect}"
      end
    end

    #
    # Read in a problem definition in svmlight format.
    #
    # Each non-blank line is a label followed by space-separated
    # index:value features. Features not mentioned on a line get value 0.
    # Parsing uses to_i/to_f, so unparsable text silently becomes 0.
    #
    # filename:: the name of the file
    #
    def self.from_file_svmlight(filename)
      rows = []
      labels = []
      max_index = 0
      File.foreach(filename) do |line|
        tokens = line.split(" ")
        next if tokens.empty? # blank line

        labels << tokens.first.to_f
        nodes = tokens.drop(1).map do |feature|
          index, value = feature.split(":")
          Node.new(index.to_i, value.to_f)
        end
        max_index = [max_index, *nodes.map(&:index)].max
        rows << nodes
      end
      # +1 to allow for 0 position
      build_problem(rows, labels, max_index + 1)
    end

    #
    # Read in a problem definition in csv format.
    #
    # The first field of each row is the label; the remaining fields are
    # the feature values, numbered from 0. Short rows are padded with 0.
    # Values are converted with to_f, so unparsable text becomes 0.0.
    #
    # filename:: the name of the file
    #
    def self.from_file_csv(filename)
      rows = []
      labels = []
      num_features = 0
      CSV.parse(File.read(filename), headers: false).each do |tokens|
        next if tokens.empty? # blank line

        labels << tokens.first.to_f
        values = tokens.drop(1)
        rows << values.each_with_index.map { |value, index| Node.new(index, value.to_f) }
        num_features = [values.size, num_features].max
      end
      build_problem(rows, labels, num_features)
    end

    #
    # Read in a problem definition in arff format.
    # Assumes all values are numbers (non-numbers converted to 0.0),
    # and that the class is the last field.
    #
    # Everything up to and including the "@data" line is skipped, as are
    # blank lines and lines starting with "%" in the data section.
    #
    # filename:: the name of the file
    #
    def self.from_file_arff(filename)
      rows = []
      labels = []
      num_features = 0
      found_data = false
      File.foreach(filename) do |line|
        unless found_data
          found_data = line.downcase.strip == "@data"
          next
        end
        record = line.strip
        next if record.empty? || record.start_with?("%")

        tokens = record.split(",")
        labels << tokens.last.to_f
        values = tokens[0...-1] # every field except the trailing class
        rows << values.each_with_index.map { |value, index| Node.new(index, value.to_f) }
        num_features = [values.size, num_features].max
      end
      build_problem(rows, labels, num_features)
    end

    # Assemble a Problem from parsed data.
    #
    # rows:: an array with, per instance, an array of Nodes; each Node is
    #        stored in the column given by its own index.
    # labels:: an array with the label of each instance.
    # num_features:: number of columns to allocate for every instance.
    #
    # Columns for which a row holds no Node are filled with a Node of
    # value 0 whose index is the column number.
    #
    # Raises ArgumentError if rows and labels differ in size.
    def self.build_problem(rows, labels, num_features)
      unless rows.size == labels.size
        raise ArgumentError, "Number of labels read differs from number of instances"
      end

      problem = Problem.new
      problem.l = rows.size
      # -- add in the training data
      problem.x = Node[rows.size, num_features].new
      rows.each_with_index do |nodes, i|
        # -- fill with blank nodes, then add known values
        num_features.times { |j| problem.x[i][j] = Node.new(j, 0) }
        nodes.each { |node| problem.x[i][node.index] = node }
      end
      # -- add in the labels
      problem.y = Java::double[labels.size].new
      labels.each_with_index { |label, i| problem.y[i] = label }

      problem
    end
    private_class_method :build_problem

    # Returns the number of instances
    def size
      self.l
    end

    # Rescale values within problem to be in range min_value to max_value
    #
    # For SVM models, it is recommended all features be in range [0,1] or [-1,1]
    #
    # Does nothing when the problem has no instances. A column whose values
    # are all equal is set to min_value.
    def rescale(min_value = 0.0, max_value = 1.0)
      return if self.l.zero?

      x[0].size.times { |col| rescale_column(col, min_value, max_value) }
    end

    # Create a new problem by combining the instances in this problem with
    # those in the given problem.
    #
    # The new problem refers to the same Node objects as the two originals;
    # nodes are not copied.
    #
    # Raises ArgumentError if the first instances of the two problems have
    # different numbers of features.
    def merge(problem)
      unless self.x[0].size == problem.x[0].size
        raise ArgumentError, "Cannot merge two problems with different numbers of features"
      end

      num_features = self.x[0].size
      num_instances = size + problem.size

      new_problem = Problem.new
      new_problem.l = num_instances
      new_problem.x = Node[num_instances, num_features].new
      new_problem.y = Java::double[num_instances].new
      num_instances.times do |i|
        # rows below +size+ come from this problem, the rest from the other
        source, row = i < size ? [self, i] : [problem, i - size]
        num_features.times { |j| new_problem.x[i][j] = source.x[row][j] }
        new_problem.y[i] = source.y[row]
      end

      new_problem
    end

    private

    # Rescale values within problem for given column index,
    # to be in range min_value to max_value
    def rescale_column(col, min_value, max_value)
      # -- first locate the column's range
      current_min, current_max = Array.new(self.l) { |row| x[row][col].value }.minmax
      range = current_max - current_min
      # -- then update each value
      self.l.times do |row|
        node = x[row][col]
        node.value =
          if range.zero?
            min_value # constant column: avoid dividing by zero
          else
            ((max_value - min_value) * (node.value - current_min) / range) + min_value
          end
      end
    end
  end
end
308
+