eluka 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (73) hide show
  1. data/.document +5 -0
  2. data/DOCUMENTATION_STANDARDS +39 -0
  3. data/Gemfile +13 -0
  4. data/Gemfile.lock +20 -0
  5. data/LICENSE.txt +20 -0
  6. data/README.rdoc +19 -0
  7. data/Rakefile +69 -0
  8. data/VERSION +1 -0
  9. data/examples/example.rb +59 -0
  10. data/ext/libsvm/COPYRIGHT +31 -0
  11. data/ext/libsvm/FAQ.html +1749 -0
  12. data/ext/libsvm/Makefile +25 -0
  13. data/ext/libsvm/Makefile.win +33 -0
  14. data/ext/libsvm/README +733 -0
  15. data/ext/libsvm/extconf.rb +1 -0
  16. data/ext/libsvm/heart_scale +270 -0
  17. data/ext/libsvm/java/Makefile +25 -0
  18. data/ext/libsvm/java/libsvm.jar +0 -0
  19. data/ext/libsvm/java/libsvm/svm.java +2776 -0
  20. data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
  21. data/ext/libsvm/java/libsvm/svm_model.java +21 -0
  22. data/ext/libsvm/java/libsvm/svm_node.java +6 -0
  23. data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
  24. data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
  25. data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
  26. data/ext/libsvm/java/svm_predict.java +163 -0
  27. data/ext/libsvm/java/svm_scale.java +350 -0
  28. data/ext/libsvm/java/svm_toy.java +471 -0
  29. data/ext/libsvm/java/svm_train.java +318 -0
  30. data/ext/libsvm/java/test_applet.html +1 -0
  31. data/ext/libsvm/python/Makefile +4 -0
  32. data/ext/libsvm/python/README +331 -0
  33. data/ext/libsvm/python/svm.py +259 -0
  34. data/ext/libsvm/python/svmutil.py +242 -0
  35. data/ext/libsvm/svm-predict.c +226 -0
  36. data/ext/libsvm/svm-scale.c +353 -0
  37. data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
  38. data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
  39. data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
  40. data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
  41. data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
  42. data/ext/libsvm/svm-toy/gtk/main.c +23 -0
  43. data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
  44. data/ext/libsvm/svm-toy/qt/Makefile +17 -0
  45. data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
  46. data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
  47. data/ext/libsvm/svm-train.c +376 -0
  48. data/ext/libsvm/svm.cpp +3060 -0
  49. data/ext/libsvm/svm.def +19 -0
  50. data/ext/libsvm/svm.h +105 -0
  51. data/ext/libsvm/svm.o +0 -0
  52. data/ext/libsvm/tools/README +149 -0
  53. data/ext/libsvm/tools/checkdata.py +108 -0
  54. data/ext/libsvm/tools/easy.py +79 -0
  55. data/ext/libsvm/tools/grid.py +359 -0
  56. data/ext/libsvm/tools/subset.py +146 -0
  57. data/ext/libsvm/windows/libsvm.dll +0 -0
  58. data/ext/libsvm/windows/svm-predict.exe +0 -0
  59. data/ext/libsvm/windows/svm-scale.exe +0 -0
  60. data/ext/libsvm/windows/svm-toy.exe +0 -0
  61. data/ext/libsvm/windows/svm-train.exe +0 -0
  62. data/lib/eluka.rb +10 -0
  63. data/lib/eluka/bijection.rb +23 -0
  64. data/lib/eluka/data_point.rb +36 -0
  65. data/lib/eluka/document.rb +47 -0
  66. data/lib/eluka/feature_vector.rb +86 -0
  67. data/lib/eluka/features.rb +31 -0
  68. data/lib/eluka/model.rb +129 -0
  69. data/lib/fselect.rb +321 -0
  70. data/lib/grid.rb +25 -0
  71. data/test/helper.rb +18 -0
  72. data/test/test_eluka.rb +7 -0
  73. metadata +214 -0
@@ -0,0 +1,259 @@
1
+ #!/usr/bin/env python
2
+
3
+ from ctypes import *
4
+ from ctypes.util import find_library
5
+ import sys
6
+
7
# Load the LIBSVM shared library into the module-level name `libsvm`.
# For unix the prefix 'lib' is not considered by find_library(), so both
# spellings are tried before falling back to a path relative to the
# current working directory.
if find_library('svm'):
    libsvm = CDLL(find_library('svm'))
elif find_library('libsvm'):
    libsvm = CDLL(find_library('libsvm'))
elif sys.platform == 'win32':
    libsvm = CDLL('../windows/libsvm.dll')
else:
    libsvm = CDLL('../libsvm.so.2')
17
+
18
# Construct constants: every name below becomes a module-level integer whose
# value is its position in the list (C_SVC = 0, ..., NU_SVR = 4 and
# LINEAR = 0, ..., PRECOMPUTED = 4).
SVM_TYPE = ['C_SVC', 'NU_SVC', 'ONE_CLASS', 'EPSILON_SVR', 'NU_SVR']
KERNEL_TYPE = ['LINEAR', 'POLY', 'RBF', 'SIGMOID', 'PRECOMPUTED']
for i, s in enumerate(SVM_TYPE):
    globals()[s] = i
for i, s in enumerate(KERNEL_TYPE):
    globals()[s] = i
23
+
24
# ctypes callback prototype: a function taking one char* and returning nothing.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)


def print_null(s):
    """Ignore the string s and return None (a do-nothing print callback)."""
    return None
27
+
28
def genFields(names, types):
    """Return a list of (name, type) tuples pairing names with types in order."""
    return [(name, ctype) for name, ctype in zip(names, types)]
30
+
31
def fillprototype(f, restype, argtypes):
    """Set f.restype and then f.argtypes to the given values."""
    for attr, spec in (("restype", restype), ("argtypes", argtypes)):
        setattr(f, attr, spec)
34
+
35
class svm_node(Structure):
    """ctypes structure with an int field `index` and a double field `value`."""
    _names = ["index", "value"]
    _types = [c_int, c_double]
    # Pair each field name with its ctypes type, in declaration order.
    _fields_ = list(zip(_names, _types))
39
+
40
def gen_svm_nodearray(xi, feature_max=None, issparse=None):
    """
    gen_svm_nodearray(xi [, feature_max [, issparse]]) -> (nodes, max_idx)

    Convert one instance xi into a ctypes array of svm_node.

    xi is either a dict mapping index -> value, or a list/tuple whose
    positions are used as indices. When feature_max is truthy, indices
    greater than feature_max are dropped; when issparse is truthy, entries
    whose value equals 0 are dropped. The remaining entries are stored in
    ascending index order and followed by one extra node whose index is -1.

    max_idx is the largest index stored, or 0 if nothing was stored.
    Raises TypeError if xi is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        candidates = xi.keys()
    elif isinstance(xi, (list, tuple)):
        candidates = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    if feature_max:
        assert(isinstance(feature_max, int))
        candidates = [j for j in candidates if j <= feature_max]
    if issparse:
        candidates = [j for j in candidates if xi[j] != 0]

    kept = sorted(candidates)
    # One slot more than the number of entries: the last slot is the
    # terminating node with index -1.
    nodes = (svm_node * (len(kept) + 1))()
    nodes[-1].index = -1
    for pos, j in enumerate(kept):
        nodes[pos].index = j
        nodes[pos].value = xi[j]

    max_idx = kept[-1] if kept else 0
    return nodes, max_idx
64
+
65
class svm_problem(Structure):
    """
    ctypes structure describing a training set: l (number of instances),
    y (pointer to l doubles) and x (pointer to l svm_node pointers).

    Two Python-side attributes are set as well: x_space, the list of
    svm_node arrays built for the instances, and n, the largest feature
    index seen across all instances (0 if there are none).
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x):
        """Build the problem from labels y and instances x; raise ValueError if their lengths differ."""
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        # Convert every instance once, remembering both the node array and
        # the largest index it contains.
        converted = [gen_svm_nodearray(xi) for xi in x]
        self.x_space = [nodes for nodes, _ in converted]
        self.n = max([0] + [top for _, top in converted])

        labels = (c_double * count)()
        for pos, yi in enumerate(y):
            labels[pos] = yi
        self.y = labels

        rows = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(self.x_space):
            rows[pos] = nodes
        self.x = rows
88
+
89
class svm_parameter(Structure):
    """
    ctypes structure holding the training parameters, filled from an
    option string such as '-c 4 -t 0'.

    Besides the structure fields listed in _names, every instance carries
    three Python-side attributes: cross_validation, nr_fold and print_func.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self, options=None):
        """Initialise from the option string `options`; None is treated as ''."""
        if options is None:
            options = ''
        self.parse_options(options)

    def show(self):
        """Print each structure field and each Python-side attribute with its value."""
        # dict.keys() is a view on Python 3 and cannot be concatenated to a
        # list directly, so it is materialised first.
        attrs = svm_parameter._names + list(self.__dict__.keys())
        for attr in attrs:
            print(' %s: %s' % (attr, getattr(self, attr)))

    def set_to_default_values(self):
        """Reset every field and Python-side attribute to its default."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = (c_int*0)()
        self.weight = (c_double*0)()
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = None

    def parse_options(self, options):
        """
        Reset to defaults, then apply the whitespace-separated `options`.

        Raises ValueError for an unknown option or for '-v n' with n < 2.
        A missing or malformed option value surfaces as the IndexError or
        ValueError raised by the underlying lookup/conversion.
        """
        argv = options.split()
        self.set_to_default_values()
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        # Options that take exactly one value and map straight onto a field.
        int_options = {"-s": "svm_type", "-t": "kernel_type", "-d": "degree",
                       "-h": "shrinking", "-b": "probability"}
        float_options = {"-g": "gamma", "-r": "coef0", "-n": "nu",
                         "-m": "cache_size", "-c": "C", "-e": "eps", "-p": "p"}

        i = 0
        while i < len(argv):
            flag = argv[i]
            if flag in int_options:
                i += 1
                setattr(self, int_options[flag], int(argv[i]))
            elif flag in float_options:
                i += 1
                setattr(self, float_options[flag], float(argv[i]))
            elif flag == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif flag == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif flag.startswith("-w"):
                # '-w<label> <weight>': the label is glued to the flag itself.
                i += 1
                self.nr_weight += 1
                weight_label.append(int(flag[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        self.weight_label = (c_int*self.nr_weight)(*weight_label)
        self.weight = (c_double*self.nr_weight)(*weight)
198
+
199
class svm_model(Structure):
    """
    Python handle for a model; declares no _fields_ and is only passed to
    libsvm functions.

    The __createfrom__ attribute records where the object came from:
    'python' when built through this constructor, 'C' when tagged by
    toPyModel().
    """

    def __init__(self):
        self.__createfrom__ = 'python'

    def __del__(self):
        # free memory created by C to avoid memory leak
        origin = getattr(self, '__createfrom__', None)
        if origin == 'C':
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the result of libsvm.svm_get_svm_type for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the result of libsvm.svm_get_nr_class for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the result of libsvm.svm_get_svr_probability for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return the model's labels as a Python list with get_nr_class() entries."""
        count = self.get_nr_class()
        buf = (c_int * count)()
        libsvm.svm_get_labels(self, buf)
        return list(buf)

    def is_probability_model(self):
        """Return True when libsvm.svm_check_probability_model reports 1."""
        return libsvm.svm_check_probability_model(self) == 1
225
+
226
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Convert a ctypes POINTER(svm_model) to a Python svm_model.
    Raises ValueError("Null pointer") when model_ptr is falsy.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    # Tag the object as coming from C; svm_model.__del__ checks this tag.
    model.__createfrom__ = 'C'
    return model
237
+
238
# Declare the return type and argument types of every libsvm function used
# by this module: (function name, restype, argtypes).
for _name, _restype, _argtypes in (
    ('svm_train', POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_cross_validation', None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),

    ('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
    ('svm_load_model', POINTER(svm_model), [c_char_p]),

    ('svm_get_svm_type', c_int, [POINTER(svm_model)]),
    ('svm_get_nr_class', c_int, [POINTER(svm_model)]),
    ('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
    ('svm_get_svr_probability', c_double, [POINTER(svm_model)]),

    ('svm_predict_values', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
    ('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
    ('svm_predict_probability', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),

    ('svm_free_model_content', None, [POINTER(svm_model)]),
    ('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
    ('svm_destroy_param', None, [POINTER(svm_parameter)]),

    ('svm_check_parameter', c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_check_probability_model', c_int, [POINTER(svm_model)]),
    ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
):
    fillprototype(getattr(libsvm, _name), _restype, _argtypes)
# Do not leave the loop variables behind as module attributes.
del _name, _restype, _argtypes
@@ -0,0 +1,242 @@
1
+ #!/usr/bin/env python
2
+
3
+ from svm import *
4
+
5
def svm_read_problem(data_file_name):
    """
    svm_read_problem(data_file_name) -> [y, x]

    Read LIBSVM-format data from data_file_name and return labels y
    and data instances x.

    Each non-blank line is '<label> <index>:<value> ...'. y is a list of
    floats and x a list of dicts mapping int index -> float value; a line
    with a label only yields an empty dict. Blank lines are skipped.
    The file is closed before returning, even if parsing fails.
    """
    prob_y = []
    prob_x = []
    with open(data_file_name) as data_file:
        for line in data_file:
            parts = line.split(None, 1)
            if not parts:
                # Blank (or whitespace-only) line: nothing to parse.
                continue
            # In case an instance with all zero features
            if len(parts) == 1:
                parts.append('')
            label, features = parts
            xi = {}
            for e in features.split():
                ind, val = e.split(":")
                xi[int(ind)] = float(val)
            prob_y.append(float(label))
            prob_x.append(xi)
    return (prob_y, prob_x)
26
+
27
def svm_load_model(model_file_name):
    """
    svm_load_model(model_file_name) -> model

    Load a LIBSVM model from model_file_name and return it.
    If libsvm returns a null pointer, a message is printed and None is
    returned.
    """
    # libsvm.svm_load_model is declared with a c_char_p argument, which on
    # Python 3 accepts bytes but not str, so text names are encoded first.
    c_file_name = model_file_name
    if not isinstance(c_file_name, bytes):
        c_file_name = c_file_name.encode()
    model = libsvm.svm_load_model(c_file_name)
    if not model:
        print("can't open model file %s" % model_file_name)
        return None
    model = toPyModel(model)
    return model
39
+
40
def svm_save_model(model_file_name, model):
    """
    svm_save_model(model_file_name, model) -> None

    Save a LIBSVM model to the file model_file_name.
    The integer status returned by libsvm is not inspected.
    """
    # libsvm.svm_save_model is declared with a c_char_p first argument, which
    # on Python 3 accepts bytes but not str, so text names are encoded first.
    c_file_name = model_file_name
    if not isinstance(c_file_name, bytes):
        c_file_name = c_file_name.encode()
    libsvm.svm_save_model(c_file_name, model)
47
+
48
def evaluations(ty, pv):
    """
    evaluations(ty, pv) -> (ACC, MSE, SCC)

    Calculate accuracy, mean squared error and squared correlation coefficient
    using the true values (ty) and predicted values (pv).

    ACC is a percentage (0-100). SCC is NaN when it cannot be computed
    (for example when either sequence is constant, which makes the
    denominator zero).

    Raises ValueError if the two sequences differ in length, and
    ZeroDivisionError if they are empty.
    """
    if len(ty) != len(pv):
        raise ValueError("len(ty) must equal to len(pv)")
    total_correct = total_error = 0
    sumv = sumy = sumvv = sumyy = sumvy = 0
    for v, y in zip(pv, ty):
        if y == v:
            total_correct += 1
        total_error += (v-y)*(v-y)
        sumv += v
        sumy += y
        sumvv += v*v
        sumyy += y*y
        sumvy += v*y
    l = len(ty)
    ACC = 100.0*total_correct/l
    MSE = total_error/l
    try:
        covariance = l*sumvy - sumv*sumy
        SCC = (covariance*covariance)/((l*sumvv-sumv*sumv)*(l*sumyy-sumy*sumy))
    except ArithmeticError:
        # Only arithmetic failures (zero denominator, overflow) mean "no
        # correlation available"; anything else should propagate.
        SCC = float('nan')
    return (ACC, MSE, SCC)
76
+
77
def svm_train(arg1, arg2=None, arg3=None):
    """
    svm_train(y, x [, 'options']) -> model | ACC | MSE
    svm_train(prob [, 'options']) -> model | ACC | MSE
    svm_train(prob, param) -> model | ACC | MSE

    Train an SVM model from data (y, x) or an svm_problem prob using
    'options' or an svm_parameter param.
    If '-v' is specified in 'options' (i.e., cross validation)
    either accuracy (ACC) or mean-squared error (MSE) is returned.
    'options':
        -s svm_type : set type of SVM (default 0)
            0 -- C-SVC
            1 -- nu-SVC
            2 -- one-class SVM
            3 -- epsilon-SVR
            4 -- nu-SVR
        -t kernel_type : set type of kernel function (default 2)
            0 -- linear: u'*v
            1 -- polynomial: (gamma*u'*v + coef0)^degree
            2 -- radial basis function: exp(-gamma*|u-v|^2)
            3 -- sigmoid: tanh(gamma*u'*v + coef0)
            4 -- precomputed kernel (kernel values in training_set_file)
        -d degree : set degree in kernel function (default 3)
        -g gamma : set gamma in kernel function (default 1/num_features)
        -r coef0 : set coef0 in kernel function (default 0)
        -c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
        -n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
        -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
        -m cachesize : set cache memory size in MB (default 100)
        -e epsilon : set tolerance of termination criterion (default 0.001)
        -h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)
        -b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
        -wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)
        -v n: n-fold cross validation mode
        -q : quiet mode (no outputs)

    Raises TypeError when the arguments match none of the call forms and
    ValueError when the precomputed-kernel input is malformed or libsvm
    rejects the parameters.
    """
    # Work out which of the three call forms was used.
    prob, param = None, None
    if isinstance(arg1, (list, tuple)):
        assert isinstance(arg2, (list, tuple))
        prob = svm_problem(arg1, arg2)
        param = svm_parameter(arg3)
    elif isinstance(arg1, svm_problem):
        prob = arg1
        param = arg2 if isinstance(arg2, svm_parameter) else svm_parameter(arg2)
    if prob is None or param is None:
        raise TypeError("Wrong types for the arguments")

    if param.kernel_type == PRECOMPUTED:
        # Each instance must start with node 0:<serial number>, and the
        # serial number must lie in (0, prob.n].
        for nodes in prob.x_space:
            head = nodes[0]
            if head.index != 0:
                raise ValueError('Wrong input format: first column must be 0:sample_serial_number')
            serial = head.value
            if serial <= 0 or serial > prob.n:
                raise ValueError('Wrong input format: sample_serial_number out of range')

    # gamma left at 0 means "use 1/num_features".
    if param.gamma == 0 and prob.n > 0:
        param.gamma = 1.0 / prob.n
    libsvm.svm_set_print_string_function(param.print_func)
    err_msg = libsvm.svm_check_parameter(prob, param)
    if err_msg:
        raise ValueError('Error: %s' % err_msg)

    if not param.cross_validation:
        m = toPyModel(libsvm.svm_train(prob, param))
        # If prob is destroyed, data including SVs pointed by m can remain.
        m.x_space = prob.x_space
        return m

    l, nr_fold = prob.l, param.nr_fold
    target = (c_double * l)()
    libsvm.svm_cross_validation(prob, param, nr_fold, target)
    ACC, MSE, SCC = evaluations(prob.y[:l], target[:l])
    if param.svm_type in [EPSILON_SVR, NU_SVR]:
        print("Cross Validation Mean squared error = %g" % MSE)
        print("Cross Validation Squared correlation coefficient = %g" % SCC)
        return MSE
    print("Cross Validation Accuracy = %g%%" % ACC)
    return ACC
163
+
164
def svm_predict(y, x, m, options=""):
    """
    svm_predict(y, x, m [, "options"]) -> (p_labels, p_acc, p_vals)

    Predict data (y, x) with the SVM model m.
    "options":
        -b probability_estimates: whether to predict probability estimates,
            0 or 1 (default 0); for one-class SVM only 0 is supported.

    The return tuple contains
    p_labels: a list of predicted labels
    p_acc: a tuple including accuracy (for classification), mean-squared
           error, and squared correlation coefficient (for regression).
    p_vals: a list of decision values or probability estimates (if '-b 1'
            is specified). If k is the number of classes, for decision values,
            each element includes results of predicting k(k-1)/2 binary-class
            SVMs. For probabilities, each element contains k values indicating
            the probability that the testing instance is in each class.
            Note that the order of classes here is the same as 'model.label'
            field in the model structure.

    Raises ValueError for an unknown option, or when '-b 1' is requested
    for a model without probability information.
    """
    predict_probability = 0
    argv = options.split()
    i = 0
    while i < len(argv):
        if argv[i] == '-b':
            i += 1
            predict_probability = int(argv[i])
        else:
            raise ValueError("Wrong options")
        i += 1

    svm_type = m.get_svm_type()
    is_prob_model = m.is_probability_model()
    nr_class = m.get_nr_class()
    pred_labels = []
    pred_values = []

    if predict_probability:
        if not is_prob_model:
            raise ValueError("Model does not support probabiliy estimates")

        if svm_type in [NU_SVR, EPSILON_SVR]:
            print("Prob. model for test data: target value = predicted value + z,\n"
                  "z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g" % m.get_svr_probability())
            nr_class = 0

        prob_estimates = (c_double * nr_class)()
        for xi in x:
            xi, idx = gen_svm_nodearray(xi)
            label = libsvm.svm_predict_probability(m, xi, prob_estimates)
            values = prob_estimates[:nr_class]
            pred_labels += [label]
            pred_values += [values]
    else:
        if is_prob_model:
            print("Model supports probability estimates, but disabled in predicton.")
        # One-class and regression models produce a single decision value;
        # classifiers (C-SVC and nu-SVC) produce one per pair of classes.
        # nu-SVC must therefore take the pairwise branch, otherwise libsvm
        # would be handed a one-element buffer for a multi-class model.
        if svm_type in (ONE_CLASS, EPSILON_SVR, NU_SVR):
            nr_classifier = 1
        else:
            nr_classifier = nr_class*(nr_class-1)//2
        dec_values = (c_double * nr_classifier)()
        for xi in x:
            xi, idx = gen_svm_nodearray(xi)
            label = libsvm.svm_predict_values(m, xi, dec_values)
            values = dec_values[:nr_classifier]
            pred_labels += [label]
            pred_values += [values]

    ACC, MSE, SCC = evaluations(y, pred_labels)
    l = len(y)
    if svm_type in [EPSILON_SVR, NU_SVR]:
        print("Mean squared error = %g (regression)" % MSE)
        print("Squared correlation coefficient = %g (regression)" % SCC)
    else:
        print("Accuracy = %g%% (%d/%d) (classification)" % (ACC, int(l*ACC/100), l))

    return pred_labels, (ACC, MSE, SCC), pred_values
242
+