eluka 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (73) hide show
  1. data/.document +5 -0
  2. data/DOCUMENTATION_STANDARDS +39 -0
  3. data/Gemfile +13 -0
  4. data/Gemfile.lock +20 -0
  5. data/LICENSE.txt +20 -0
  6. data/README.rdoc +19 -0
  7. data/Rakefile +69 -0
  8. data/VERSION +1 -0
  9. data/examples/example.rb +59 -0
  10. data/ext/libsvm/COPYRIGHT +31 -0
  11. data/ext/libsvm/FAQ.html +1749 -0
  12. data/ext/libsvm/Makefile +25 -0
  13. data/ext/libsvm/Makefile.win +33 -0
  14. data/ext/libsvm/README +733 -0
  15. data/ext/libsvm/extconf.rb +1 -0
  16. data/ext/libsvm/heart_scale +270 -0
  17. data/ext/libsvm/java/Makefile +25 -0
  18. data/ext/libsvm/java/libsvm.jar +0 -0
  19. data/ext/libsvm/java/libsvm/svm.java +2776 -0
  20. data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
  21. data/ext/libsvm/java/libsvm/svm_model.java +21 -0
  22. data/ext/libsvm/java/libsvm/svm_node.java +6 -0
  23. data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
  24. data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
  25. data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
  26. data/ext/libsvm/java/svm_predict.java +163 -0
  27. data/ext/libsvm/java/svm_scale.java +350 -0
  28. data/ext/libsvm/java/svm_toy.java +471 -0
  29. data/ext/libsvm/java/svm_train.java +318 -0
  30. data/ext/libsvm/java/test_applet.html +1 -0
  31. data/ext/libsvm/python/Makefile +4 -0
  32. data/ext/libsvm/python/README +331 -0
  33. data/ext/libsvm/python/svm.py +259 -0
  34. data/ext/libsvm/python/svmutil.py +242 -0
  35. data/ext/libsvm/svm-predict.c +226 -0
  36. data/ext/libsvm/svm-scale.c +353 -0
  37. data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
  38. data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
  39. data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
  40. data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
  41. data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
  42. data/ext/libsvm/svm-toy/gtk/main.c +23 -0
  43. data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
  44. data/ext/libsvm/svm-toy/qt/Makefile +17 -0
  45. data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
  46. data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
  47. data/ext/libsvm/svm-train.c +376 -0
  48. data/ext/libsvm/svm.cpp +3060 -0
  49. data/ext/libsvm/svm.def +19 -0
  50. data/ext/libsvm/svm.h +105 -0
  51. data/ext/libsvm/svm.o +0 -0
  52. data/ext/libsvm/tools/README +149 -0
  53. data/ext/libsvm/tools/checkdata.py +108 -0
  54. data/ext/libsvm/tools/easy.py +79 -0
  55. data/ext/libsvm/tools/grid.py +359 -0
  56. data/ext/libsvm/tools/subset.py +146 -0
  57. data/ext/libsvm/windows/libsvm.dll +0 -0
  58. data/ext/libsvm/windows/svm-predict.exe +0 -0
  59. data/ext/libsvm/windows/svm-scale.exe +0 -0
  60. data/ext/libsvm/windows/svm-toy.exe +0 -0
  61. data/ext/libsvm/windows/svm-train.exe +0 -0
  62. data/lib/eluka.rb +10 -0
  63. data/lib/eluka/bijection.rb +23 -0
  64. data/lib/eluka/data_point.rb +36 -0
  65. data/lib/eluka/document.rb +47 -0
  66. data/lib/eluka/feature_vector.rb +86 -0
  67. data/lib/eluka/features.rb +31 -0
  68. data/lib/eluka/model.rb +129 -0
  69. data/lib/fselect.rb +321 -0
  70. data/lib/grid.rb +25 -0
  71. data/test/helper.rb +18 -0
  72. data/test/test_eluka.rb +7 -0
  73. metadata +214 -0
@@ -0,0 +1,359 @@
1
+ #!/usr/bin/env python
2
+
3
+
4
+
5
+ import os, sys, traceback
6
+ import getpass
7
+ from threading import Thread
8
+ from subprocess import *
9
+
10
+ if(sys.hexversion < 0x03000000):
11
+ import Queue
12
+ else:
13
+ import queue as Queue
14
+
15
+
16
+ # svmtrain and gnuplot executable
17
+
18
# Platform-dependent default locations of the two external programs.
is_win32 = sys.platform == 'win32'
if is_win32:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# global parameters and their default values

fold = 5
(c_begin, c_end, c_step) = (-5, 15, 2)
(g_begin, g_end, g_step) = (3, -15, -2)
# The names below receive their values in process_options(); a global
# statement at module level has no effect of its own.
global dataset_pathname, dataset_title, pass_through_string
global out_filename, png_filename

# experimental

telnet_workers = []
ssh_workers = []
nr_local_worker = 1
40
+
41
+ # process command line options, set global parameters
42
def process_options(argv=sys.argv):
    """Parse the command line and set the module-level search parameters.

    argv -- argument list laid out like sys.argv: argv[0] is the program
            name, the last element is the dataset path and everything in
            between is an option.

    The options -log2c, -log2g, -v, -svmtrain, -gnuplot, -out and -png
    update the matching globals; every other argument is collected into
    the global pass_through_string.  On success a gnuplot process is
    started and its stdin pipe is stored in the global ``gnuplot``.

    Prints the usage text and exits with status 1 when no dataset is
    given, when the renamed -c/-g options are used, or when an option that
    needs a value is the last thing before the dataset.  Raises
    AssertionError when the svm-train executable, the gnuplot executable
    or the dataset does not exist.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    options_with_value = ('-log2c', '-log2g', '-v', '-svmtrain',
                          '-gnuplot', '-out', '-png')
    dataset_index = len(argv) - 1

    i = 1
    while i < dataset_index:
        option = argv[i]
        if option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif option in options_with_value:
            i = i + 1
            if i >= dataset_index:
                # The next argument is the dataset itself, so the option
                # has no value; do not silently consume the dataset path.
                print("Option %s requires a value." % option)
                print(usage)
                sys.exit(1)
            value = argv[i]
            if option == '-log2c':
                (c_begin, c_end, c_step) = map(float, value.split(","))
            elif option == '-log2g':
                (g_begin, g_end, g_step) = map(float, value.split(","))
            elif option == '-v':
                # kept as a string: it is only pasted into the command line
                fold = value
            elif option == '-svmtrain':
                svmtrain_exe = value
            elif option == '-gnuplot':
                gnuplot_exe = value
            elif option == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        else:
            pass_through_options.append(option)
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit checks instead of assert statements so that they still run
    # under "python -O"; AssertionError is kept for backward compatibility.
    required_paths = (
        (svmtrain_exe, "svm-train executable not found"),
        (gnuplot_exe, "gnuplot executable not found"),
        (dataset_pathname, "dataset not found"),
    )
    for path, message in required_paths:
        if not os.path.exists(path):
            raise AssertionError(message)

    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
101
+
102
+
103
def range_f(begin, end, step):
    """Return the arithmetic sequence begin, begin+step, ... as a list.

    Works like range() but also accepts non-integer values, and ``end`` is
    inclusive: the sequence stops once the next value would pass ``end``
    in the direction of ``step``.

    Raises ValueError when step is zero, because the sequence would never
    terminate.
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")

    seq = []
    current = begin
    if step > 0:
        while current <= end:
            seq.append(current)
            current = current + step
    else:
        while current >= end:
            seq.append(current)
            current = current + step
    return seq
112
+
113
def permute_sequence(seq):
    """Return the elements of seq reordered "middle first".

    The middle element comes first, followed by the recursively permuted
    lower and upper halves interleaved one element at a time.  Sequences
    of length 0 or 1 are returned unchanged.
    """
    size = len(seq)
    if size <= 1:
        return seq

    middle = int(size / 2)
    lower = permute_sequence(seq[:middle])
    upper = permute_sequence(seq[middle + 1:])

    ordered = [seq[middle]]
    for position in range(max(len(lower), len(upper))):
        if position < len(lower):
            ordered.append(lower[position])
        if position < len(upper):
            ordered.append(upper[position])
    return ordered
127
+
128
def redraw(db, best_param, tofile=False):
    """Send the current results to gnuplot as a contour plot.

    db         -- list of 3-tuples; element 0 and 1 are used as the plot
                  coordinates and element 2 as the plotted value.  The list
                  is sorted in place.
    best_param -- sequence (best_log2c, best_log2g, best_rate) shown in the
                  plot labels.
    tofile     -- when true, plot to the PNG file named by the global
                  png_filename instead of the interactive terminal.

    Does nothing when db is empty.  All commands are written, encoded, to
    the global ``gnuplot`` pipe.
    """
    if not db:
        return

    def send(text):
        gnuplot.write(text.encode())

    begin_level = round(max(entry[2] for entry in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # choose the output terminal
    if tofile:
        send("set term png transparent small\n")
        send("set output \"%s\"\n" % png_filename.replace('\\', '\\\\'))
        #send("set term postscript color solid\n")
        #send("set output \"%s.ps\"\n" % dataset_title)
    elif is_win32:
        send("set term windows\n")
    else:
        send("set term x11\n")

    # axes, contour levels and labels
    send("set xlabel \"log2(C)\"\n")
    send("set ylabel \"log2(gamma)\"\n")
    send("set xrange [%s:%s]\n" % (c_begin, c_end))
    send("set yrange [%s:%s]\n" % (g_begin, g_end))
    send("set contour\n")
    send("set cntrparam levels incremental %s,%s,100\n" % (begin_level, step_size))
    send("unset surface\n")
    send("unset ztics\n")
    send("set view 0,0\n")
    send("set title \"%s\"\n" % dataset_title)
    send("unset label\n")
    send("set label \"Best log2(C) = %s log2(gamma) = %s accuracy = %s%%\" at screen 0.5,0.85 center\n"
         % (best_log2c, best_log2g, best_rate))
    send("set label \"C = %s gamma = %s\" at screen 0.5,0.8 center\n"
         % (2**best_log2c, 2**best_log2g))
    send("splot \"-\" with lines\n")

    # data points, one blank line between groups sharing the same first value
    db.sort(key=lambda entry: (entry[0], -entry[1]))

    previous_c = db[0][0]
    for entry in db:
        if entry[0] != previous_c:
            send("\n")
            previous_c = entry[0]
        send("%s %s %s\n" % entry)
    send("e\n")
    send("\n")  # force gnuplot back to prompt when term set failure
    gnuplot.flush()
176
+
177
+
178
def calculate_jobs():
    """Build the list of job batches covering the whole (C, gamma) grid.

    The C and gamma exponent sequences come from the global begin/end/step
    settings, each reordered by permute_sequence().  Every batch is a list
    of (c, g) exponent pairs: a batch either pairs one new C value with all
    gamma values used so far, or one new gamma value with all C values
    used so far, whichever axis is currently behind in relative progress.
    """
    c_values = permute_sequence(range_f(c_begin, c_end, c_step))
    g_values = permute_sequence(range_f(g_begin, g_end, g_step))
    c_total = float(len(c_values))
    g_total = float(len(g_values))

    c_used = 0
    g_used = 0
    batches = []

    while c_used < c_total or g_used < g_total:
        if c_used / c_total < g_used / g_total:
            # increase C resolution
            batch = [(c_values[c_used], g) for g in g_values[:g_used]]
            c_used += 1
        else:
            # increase g resolution
            batch = [(c, g_values[g_used]) for c in c_values[:c_used]]
            g_used += 1
        batches.append(batch)
    return batches
203
+
204
class WorkerStopToken:
    """Marker put on the job queue to notify the workers to stop."""


class Worker(Thread):
    """Thread that takes (cexp, gexp) jobs from a queue and reports rates.

    Subclasses provide run_one(c, g), which must return the rate for the
    given parameter values, or None on failure.
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token arrives or a job fails."""
        while True:
            job = self.job_queue.get()
            (cexp, gexp) = job
            if cexp is WorkerStopToken:
                # hand the token back so it stays available on the queue
                self.job_queue.put((cexp, gexp))
                # print 'worker %s stop.' % self.name
                break

            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except:
                # Deliberately broad: we failed, let others do that job
                # and we just quit.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                print('worker %s quit.' % self.name)
                break

            self.result_queue.put((self.name, cexp, gexp, rate))
233
+
234
class LocalWorker(Worker):
    """Worker that runs svm-train as a subprocess on the local machine."""

    def run_one(self, c, g):
        """Run one svm-train job and return the reported rate.

        c, g -- parameter values passed to svm-train as -c and -g.

        Returns the number at the end of the first output line containing
        "Cross" (with its trailing character stripped), or None when no
        such line is printed.
        """
        # NOTE(review): the command is built as a shell string from the
        # command-line options; paths with spaces or shell metacharacters
        # are not quoted.
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
                  (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        process = Popen(cmdline, shell=True, stdout=PIPE)
        # communicate() reads all output, closes the pipe and waits for the
        # child, so neither the file descriptor nor the process is leaked.
        output = process.communicate()[0]

        marker = "Cross".encode()
        for line in output.splitlines():
            if marker in line:
                return float(line.split()[-1][0:-1])
        return None
242
+
243
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh``."""

    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        # the remote command changes into the directory we were started in
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Run one svm-train job on self.host and return the reported rate.

        c, g -- parameter values passed to svm-train as -c and -g.

        Returns the number at the end of the first output line containing
        "Cross" (with its trailing character stripped), or None when no
        such line is printed.
        """
        # NOTE(review): the command is built as a shell string; host, paths
        # and pass-through options are not quoted.
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
                  (self.host, self.cwd,
                   svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        process = Popen(cmdline, shell=True, stdout=PIPE)
        # communicate() reads all output, closes the pipe and waits for the
        # child, so neither the file descriptor nor the process is leaked.
        output = process.communicate()[0]

        marker = "Cross".encode()
        for line in output.splitlines():
            if marker in line:
                return float(line.split()[-1][0:-1])
        return None
256
+
257
class TelnetWorker(Worker):
    """Worker that drives svm-train on a remote host over a telnet session.

    NOTE(review): telnet transmits the password in clear text, and the
    telnetlib module was removed from the standard library in Python 3.13;
    confirm this experimental worker is still wanted.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.username = username
        self.password = password

    @staticmethod
    def _to_bytes(text):
        """Return text as bytes; telnetlib on Python 3 only accepts bytes."""
        if isinstance(text, bytes):
            return text
        return text.encode()

    def run(self):
        """Log in, change to the local working directory, then process jobs."""
        import telnetlib
        to_bytes = self._to_bytes
        self.tn = tn = telnetlib.Telnet(self.host)
        try:
            tn.read_until(to_bytes("login: "))
            tn.write(to_bytes(self.username + "\n"))
            tn.read_until(to_bytes("Password: "))
            tn.write(to_bytes(self.password + "\n"))

            # XXX: how to know whether login is successful?
            tn.read_until(to_bytes(self.username))
            #
            print('login ok', self.host)
            tn.write(to_bytes("cd " + os.getcwd() + "\n"))
            Worker.run(self)
            tn.write(to_bytes("exit\n"))
        finally:
            # always release the connection, even when the login fails
            tn.close()

    def run_one(self, c, g):
        """Run one svm-train job in the telnet session and return its rate.

        c, g -- parameter values passed to svm-train as -c and -g.

        Returns the number at the end of the first received line containing
        "Cross" (with its trailing character stripped), or None when no
        such line is found.
        """
        to_bytes = self._to_bytes
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
                  (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        self.tn.write(to_bytes(cmdline + '\n'))
        (idx, matchm, output) = self.tn.expect([to_bytes('Cross.*\n')])

        marker = to_bytes("Cross")
        for line in output.split(to_bytes('\n')):
            if marker in line:
                return float(line.split()[-1][0:-1])
        return None
286
+
287
def main():
    """Run the grid search: start workers, collect results, plot and report.

    Reads its configuration from the command line via process_options().
    Every result is appended to the file named by out_filename as
    "c g rate", the plot is redrawn after each finished batch of jobs, and
    the best (C, gamma, rate) found is printed at the end.
    """
    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # From now on put() inserts at the front of the queue, so jobs that a
    # failing worker puts back and the stop token are seen first.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    # defined up front so the final report cannot hit an unbound name when
    # no result was collected
    best_c, best_g = None, None

    result_file = open(out_filename, 'w')
    try:
        for line in jobs:
            for (c, g) in line:
                # results arrive in arbitrary order; keep reading until the
                # job we are waiting for has been reported
                while (c, g) not in done_jobs:
                    (worker, c1, g1, rate) = result_queue.get()
                    done_jobs[(c1, g1)] = rate
                    result_file.write('%s %s %s\n' % (c1, g1, rate))
                    result_file.flush()
                    # on a tie prefer the smaller C for the same gamma
                    if (rate > best_rate) or (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                        best_rate = rate
                        best_c1, best_g1 = c1, g1
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                    print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                          (worker, c1, g1, rate, best_c, best_g, best_rate))
                db.append((c, g, done_jobs[(c, g)]))
            redraw(db, [best_c1, best_g1, best_rate])
            redraw(db, [best_c1, best_g1, best_rate], True)
    finally:
        result_file.close()

    job_queue.put((WorkerStopToken, None))
    print("%s %s %s" % (best_c, best_g, best_rate))
359
+ main()
@@ -0,0 +1,146 @@
1
+ #!/usr/bin/env python
2
+ from sys import argv, exit, stdout, stderr
3
+ from random import randint
4
+
5
# Selection method chosen with -s; 0 is the default.
method = 0
# n and dataset_filename receive their values in process_options(); a
# global statement at module level has no effect of its own.
global n, dataset_filename
# Optional output paths; an empty string means the path was not given.
subset_filename = ""
rest_filename = ""
10
+
11
def exit_with_help():
    """Print the usage text for this script and exit with status 1."""
    usage = """\
Usage: %s [options] dataset number [output1] [output2]

This script selects a subset of the given dataset.

options:
-s method : method of selection (default 0)
0 -- stratified selection (classification only)
1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen."""
    program = argv[0]
    print(usage % program)
    exit(1)
26
+
27
def process_options():
    """Parse sys.argv and set the module-level selection parameters.

    Expected layout: [options] dataset number [output1] [output2].
    Sets the globals method, n, dataset_filename and, when present,
    subset_filename and rest_filename.

    Prints the usage text and exits (via exit_with_help) when fewer than
    two arguments are given, when -s has no value or an unknown method,
    or when the dataset or number argument is missing after the options.
    Raises ValueError when the method or the number is not an integer.
    """
    global method, n
    global dataset_filename, subset_filename, rest_filename

    argc = len(argv)
    if argc < 3:
        exit_with_help()

    i = 1
    while i < argc:
        # the first argument that does not look like an option ends the
        # option list (startswith also copes with an empty argument)
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                # "-s" was the last argument: no method value to read
                exit_with_help()
            method = int(argv[i])
            if method < 0 or method > 1:
                print("Unknown selection method %d" % (method))
                exit_with_help()
        # any other argument starting with "-" is skipped
        i = i + 1

    # both positional arguments must be present after the options
    if i + 1 >= argc:
        exit_with_help()

    dataset_filename = argv[i]
    n = int(argv[i+1])
    if i+2 < argc:
        subset_filename = argv[i+2]
    if i+3 < argc:
        rest_filename = argv[i+3]
53
+
54
def main():
    """Select a subset of the dataset lines and write it out.

    Reads the configuration through process_options(), reads the label
    (first token) of every dataset line, marks lines as selected using
    either stratified (method 0) or random (method 1) selection, then
    copies selected lines to the subset output (a file, or stdout when no
    subset filename was given) and, when a second filename was given, the
    remaining lines to that file.  A warning is written to stderr when a
    class had to be bumped to one selected instance.
    """
    class Label:
        """Label value of one dataset line, its position and selection flag."""
        def __init__(self, label, index, selected):
            self.label = label
            self.index = index
            self.selected = selected

    process_options()

    # get labels
    labels = []
    f = open(dataset_filename, 'r')
    try:
        for index, line in enumerate(f):
            labels.append(Label(float((line.split())[0]), index, 0))
    finally:
        f.close()
    l = len(labels)

    split = rest_filename != ""
    file1 = stdout
    file2 = None
    try:
        # determine where to output
        if subset_filename != "":
            file1 = open(subset_filename, 'w')
        if split:
            file2 = open(rest_filename, 'w')

        # select the subset
        warning = 0
        if l == 0:
            # empty dataset: nothing to select (also avoids dividing by l)
            pass
        elif method == 0: # stratified
            labels.sort(key = lambda x: x.label)

            # sentinel with a label different from the last real one, so
            # the final class is closed inside the loop below
            label_end = labels[l-1].label + 1
            labels.append(Label(label_end, l, 0))

            begin = 0
            label = labels[begin].label
            for i in range(l+1):
                new_label = labels[i].label
                if new_label != label:
                    nr_class = i - begin
                    k = i*n//l - begin*n//l
                    # at least one instance per class
                    if k == 0:
                        k = 1
                        warning = warning + 1
                    # pick k of the nr_class instances of this class
                    for j in range(nr_class):
                        if randint(0, nr_class-j-1) < k:
                            labels[begin+j].selected = 1
                            k = k - 1
                    begin = i
                    label = new_label
        elif method == 1: # random
            k = n
            for i in range(l):
                if randint(0, l-i-1) < k:
                    labels[i].selected = 1
                    k = k - 1

        # output
        if method == 0:
            # restore the original line order before the second pass
            labels.sort(key = lambda x: int(x.index))

        f = open(dataset_filename, 'r')
        try:
            for i, line in enumerate(f):
                if labels[i].selected == 1:
                    file1.write(line)
                elif split:
                    file2.write(line)
        finally:
            f.close()

        if warning > 0:
            stderr.write("""\
Warning:
1. You may have regression data. Please use -s 1.
2. Classification data unbalanced or too small. We select at least 1 per class.
The subset thus contains %d instances.
""" % (n+warning))
    finally:
        # cleanup: close only the files opened here, never sys.stdout
        if file1 is stdout:
            file1.flush()
        else:
            file1.close()
        if file2 is not None:
            file2.close()
145
+
146
+ main()