eluka 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. data/.document +5 -0
  2. data/DOCUMENTATION_STANDARDS +39 -0
  3. data/Gemfile +13 -0
  4. data/Gemfile.lock +20 -0
  5. data/LICENSE.txt +20 -0
  6. data/README.rdoc +19 -0
  7. data/Rakefile +69 -0
  8. data/VERSION +1 -0
  9. data/examples/example.rb +59 -0
  10. data/ext/libsvm/COPYRIGHT +31 -0
  11. data/ext/libsvm/FAQ.html +1749 -0
  12. data/ext/libsvm/Makefile +25 -0
  13. data/ext/libsvm/Makefile.win +33 -0
  14. data/ext/libsvm/README +733 -0
  15. data/ext/libsvm/extconf.rb +1 -0
  16. data/ext/libsvm/heart_scale +270 -0
  17. data/ext/libsvm/java/Makefile +25 -0
  18. data/ext/libsvm/java/libsvm.jar +0 -0
  19. data/ext/libsvm/java/libsvm/svm.java +2776 -0
  20. data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
  21. data/ext/libsvm/java/libsvm/svm_model.java +21 -0
  22. data/ext/libsvm/java/libsvm/svm_node.java +6 -0
  23. data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
  24. data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
  25. data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
  26. data/ext/libsvm/java/svm_predict.java +163 -0
  27. data/ext/libsvm/java/svm_scale.java +350 -0
  28. data/ext/libsvm/java/svm_toy.java +471 -0
  29. data/ext/libsvm/java/svm_train.java +318 -0
  30. data/ext/libsvm/java/test_applet.html +1 -0
  31. data/ext/libsvm/python/Makefile +4 -0
  32. data/ext/libsvm/python/README +331 -0
  33. data/ext/libsvm/python/svm.py +259 -0
  34. data/ext/libsvm/python/svmutil.py +242 -0
  35. data/ext/libsvm/svm-predict.c +226 -0
  36. data/ext/libsvm/svm-scale.c +353 -0
  37. data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
  38. data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
  39. data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
  40. data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
  41. data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
  42. data/ext/libsvm/svm-toy/gtk/main.c +23 -0
  43. data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
  44. data/ext/libsvm/svm-toy/qt/Makefile +17 -0
  45. data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
  46. data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
  47. data/ext/libsvm/svm-train.c +376 -0
  48. data/ext/libsvm/svm.cpp +3060 -0
  49. data/ext/libsvm/svm.def +19 -0
  50. data/ext/libsvm/svm.h +105 -0
  51. data/ext/libsvm/svm.o +0 -0
  52. data/ext/libsvm/tools/README +149 -0
  53. data/ext/libsvm/tools/checkdata.py +108 -0
  54. data/ext/libsvm/tools/easy.py +79 -0
  55. data/ext/libsvm/tools/grid.py +359 -0
  56. data/ext/libsvm/tools/subset.py +146 -0
  57. data/ext/libsvm/windows/libsvm.dll +0 -0
  58. data/ext/libsvm/windows/svm-predict.exe +0 -0
  59. data/ext/libsvm/windows/svm-scale.exe +0 -0
  60. data/ext/libsvm/windows/svm-toy.exe +0 -0
  61. data/ext/libsvm/windows/svm-train.exe +0 -0
  62. data/lib/eluka.rb +10 -0
  63. data/lib/eluka/bijection.rb +23 -0
  64. data/lib/eluka/data_point.rb +36 -0
  65. data/lib/eluka/document.rb +47 -0
  66. data/lib/eluka/feature_vector.rb +86 -0
  67. data/lib/eluka/features.rb +31 -0
  68. data/lib/eluka/model.rb +129 -0
  69. data/lib/fselect.rb +321 -0
  70. data/lib/grid.rb +25 -0
  71. data/test/helper.rb +18 -0
  72. data/test/test_eluka.rb +7 -0
  73. metadata +214 -0
@@ -0,0 +1,359 @@
1
+ #!/usr/bin/env python
2
+
3
+
4
+
5
+ import os, sys, traceback
6
+ import getpass
7
+ from threading import Thread
8
+ from subprocess import *
9
+
10
+ if(sys.hexversion < 0x03000000):
11
+ import Queue
12
+ else:
13
+ import queue as Queue
14
+
15
+
16
# svmtrain and gnuplot executable

is_win32 = (sys.platform == 'win32')
if not is_win32:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"
else:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"

# global parameters and their default values

fold = 5
c_begin, c_end, c_step = -5, 15, 2
g_begin, g_end, g_step = 3, -15, -2

# Filled in by process_options(). A `global` statement at module level is a
# no-op and does not create these names, so bind explicit placeholders.
dataset_pathname = dataset_title = pass_through_string = None
out_filename = png_filename = None

# experimental

telnet_workers = []  # host names; main() starts one TelnetWorker per host
ssh_workers = []     # host names; main() starts one SSHWorker per host
nr_local_worker = 1  # number of LocalWorker threads main() starts
40
+
41
# process command line options, set global parameters
def process_options(argv=sys.argv):
    """Parse grid.py's command line and set the module-level parameters.

    The last element of ``argv`` is the dataset path. Every argument between
    ``argv[0]`` and the dataset is either a grid.py option (-log2c, -log2g,
    -v, -svmtrain, -gnuplot, -out, -png) or is passed through verbatim to
    svm-train. Finally gnuplot is started and its stdin is stored in the
    global ``gnuplot``.

    Exits with status 1 after printing usage if there are too few arguments
    or if the retired -c/-g options are used. Exits with an error message if
    svm-train, gnuplot or the dataset path does not exist.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    # The dataset is always the last argument. Output names default to its
    # base name.
    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    i = 1
    while i < len(argv) - 1:
        if argv[i] == "-log2c":
            i = i + 1
            (c_begin,c_end,c_step) = map(float,argv[i].split(","))
        elif argv[i] == "-log2g":
            i = i + 1
            (g_begin,g_end,g_step) = map(float,argv[i].split(","))
        elif argv[i] == "-v":
            i = i + 1
            fold = argv[i]
        elif argv[i] in ('-c','-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif argv[i] == '-svmtrain':
            i = i + 1
            svmtrain_exe = argv[i]
        elif argv[i] == '-gnuplot':
            i = i + 1
            gnuplot_exe = argv[i]
        elif argv[i] == '-out':
            i = i + 1
            out_filename = argv[i]
        elif argv[i] == '-png':
            i = i + 1
            png_filename = argv[i]
        else:
            # anything unrecognised is handed to svm-train unchanged
            pass_through_options.append(argv[i])
        i = i + 1

    pass_through_string = " ".join(pass_through_options)
    # Use explicit checks rather than `assert`. Asserts are stripped under
    # `python -O`, which would let a bad path through to a confusing failure
    # later (e.g. Popen raising on a missing gnuplot).
    for path, what in ((svmtrain_exe, "svm-train executable"),
                       (gnuplot_exe, "gnuplot executable"),
                       (dataset_pathname, "dataset")):
        if not os.path.exists(path):
            sys.exit("%s not found" % what)
    gnuplot = Popen(gnuplot_exe,stdin = PIPE).stdin
101
+
102
+
103
def range_f(begin,end,step):
    """Return begin, begin+step, ... up to and including end.

    This works like ``range`` but also accepts non-integer values and
    includes ``end`` when it is reached. A negative ``step`` counts down.
    Each value is computed as ``begin + k*step`` rather than by repeated
    addition, so floating-point error does not accumulate. A tolerance
    relative to ``step`` keeps an end point such as 0.1*3 (which is
    0.30000000000000004) from being dropped. Integer inputs give integer
    outputs.

    Raises:
        ValueError: if ``step`` is zero (previously an endless loop).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    tol = abs(step) * 1e-9
    seq = []
    k = 0
    while True:
        value = begin + k * step
        if step > 0 and value > end + tol: break
        if step < 0 and value < end - tol: break
        seq.append(value)
        k += 1
    return seq
112
+
113
def permute_sequence(seq):
    """Return ``seq`` reordered middle-first, with its halves interleaved recursively.

    The middle element comes first. It is followed by the recursively
    reordered left and right halves, taken alternately (left, right, left,
    ...) until both are exhausted. calculate_jobs() applies this to the C
    and gamma axes, presumably so the earliest jobs are spread across the
    whole range. A sequence of length 0 or 1 is returned as-is.
    """
    size = len(seq)
    if size <= 1:
        return seq

    pivot = size // 2
    lower = permute_sequence(seq[:pivot])
    upper = permute_sequence(seq[pivot + 1:])

    merged = [seq[pivot]]
    for pos in range(max(len(lower), len(upper))):
        if pos < len(lower):
            merged.append(lower[pos])
        if pos < len(upper):
            merged.append(upper[pos])
    return merged
127
+
128
def redraw(db,best_param,tofile=False):
    """Send the current grid-search results to gnuplot as a contour plot.

    db: list of (log2c, log2g, rate) tuples, as built by main(). It is
        sorted in place here.
    best_param: [best_log2c, best_log2g, best_rate], shown in the plot labels.
    tofile: if True, render to png_filename. Otherwise use the on-screen
        terminal (windows on win32, x11 elsewhere).

    Writes to the module-level ``gnuplot`` pipe opened by process_options().
    Does nothing when db is empty.
    """
    if len(db) == 0: return
    # contour levels start 3 below the highest (rounded) rate seen so far
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c,best_log2g,best_rate = best_param

    if tofile:
        gnuplot.write( "set term png transparent small\n".encode())
        # backslashes (e.g. in Windows paths) are doubled inside the
        # double-quoted gnuplot string
        gnuplot.write( ("set output \"%s\"\n" % png_filename.replace('\\','\\\\')).encode())
        #gnuplot.write("set term postscript color solid\n".encode())
        #gnuplot.write(("set output \"%s.ps\"\n" % dataset_title).encode())
    elif is_win32:
        gnuplot.write("set term windows\n".encode())
    else:
        gnuplot.write( "set term x11\n".encode())
    gnuplot.write("set xlabel \"log2(C)\"\n".encode())
    gnuplot.write("set ylabel \"log2(gamma)\"\n".encode())
    gnuplot.write(("set xrange [%s:%s]\n" % (c_begin,c_end)).encode())
    gnuplot.write(("set yrange [%s:%s]\n" % (g_begin,g_end)).encode())
    # draw only contour lines (no surface), viewed from directly above
    gnuplot.write("set contour\n".encode())
    # contours from begin_level up to 100 in steps of step_size
    gnuplot.write(("set cntrparam levels incremental %s,%s,100\n" % (begin_level,step_size)).encode())
    gnuplot.write("unset surface\n".encode())
    gnuplot.write("unset ztics\n".encode())
    gnuplot.write("set view 0,0\n".encode())
    gnuplot.write(("set title \"%s\"\n" % dataset_title).encode())
    gnuplot.write("unset label\n".encode())
    # NOTE: the backslash-newline below continues the string literal, so the
    # next line's leading characters become part of the gnuplot command.
    gnuplot.write(("set label \"Best log2(C) = %s log2(gamma) = %s accuracy = %s%%\" \
at screen 0.5,0.85 center\n" % \
        (best_log2c, best_log2g, best_rate)).encode())
    gnuplot.write(("set label \"C = %s gamma = %s\""
        " at screen 0.5,0.8 center\n" % (2**best_log2c, 2**best_log2g)).encode())
    # inline data follows, terminated by the "e" line written below
    gnuplot.write("splot \"-\" with lines\n".encode())

    # group points by log2c (ascending), log2g descending within a group
    db.sort(key = lambda x:(x[0], -x[1]))

    prevc = db[0][0]
    for line in db:
        # a blank line separates one log2c row of grid data from the next
        if prevc != line[0]:
            gnuplot.write("\n".encode())
            prevc = line[0]
        gnuplot.write(("%s %s %s\n" % line).encode())
    gnuplot.write("e\n".encode())
    gnuplot.write("\n".encode()) # force gnuplot back to its prompt if setting the terminal failed
    gnuplot.flush()
176
+
177
+
178
def calculate_jobs():
    """Build the (log2c, log2g) grid as a list of job batches.

    Both axes come from the global c_*/g_* ranges via range_f() and are
    reordered by permute_sequence(). The grid then grows one batch at a
    time. Each batch either adds the next C value (paired with every gamma
    used so far) or the next gamma value (paired with every C used so far).
    It picks whichever axis has the smaller covered fraction (i/nr_c vs
    j/nr_g). Each element of the result is the list of new (c, g) pairs
    that step adds; the very first batch is empty.

    Raises:
        ValueError: if either range is empty (e.g. begin > end with a
            positive step). Previously this surfaced as ZeroDivisionError.
    """
    c_seq = permute_sequence(range_f(c_begin,c_end,c_step))
    g_seq = permute_sequence(range_f(g_begin,g_end,g_step))
    if not c_seq:
        raise ValueError("empty log2c range: %s,%s,%s" % (c_begin, c_end, c_step))
    if not g_seq:
        raise ValueError("empty log2g range: %s,%s,%s" % (g_begin, g_end, g_step))
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution: next C with every gamma added so far
            jobs.append([(c_seq[i], g_seq[k]) for k in range(j)])
            i = i + 1
        else:
            # increase g resolution: next gamma with every C added so far
            jobs.append([(c_seq[k], g_seq[j]) for k in range(i)])
            j = j + 1
    return jobs
203
+
204
class WorkerStopToken:
    """Sentinel that tells a Worker to stop.

    The class object itself is the marker: it is queued as the job
    ``(WorkerStopToken, None)`` and recognised with an ``is`` comparison.
    """
206
+
207
class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from job_queue and reports rates.

    Subclasses provide run_one(c, g), which runs one cross-validation with
    C = c and gamma = g and returns the accuracy (None means no rate was
    found). Each result is put on result_queue as
    (worker_name, log2c, log2g, rate).
    """
    def __init__(self,name,job_queue,result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token arrives or a job fails."""
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # put the token back so the remaining workers see it too
                self.job_queue.put((cexp,gexp))
                break
            try:
                rate = self.run_one(2.0**cexp,2.0**gexp)
                if rate is None: raise RuntimeError("get no rate")
            except Exception:
                # We failed: report it, hand the job back so another worker
                # can do it, and quit. Catching Exception rather than using
                # a bare except lets SystemExit/KeyboardInterrupt propagate.
                traceback.print_exc()
                self.job_queue.put((cexp,gexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))
233
+
234
class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""
    def run_one(self,c,g):
        """Cross-validate with C = c, gamma = g and return the accuracy.

        Returns None if svm-train prints no line containing "Cross".
        """
        # NOTE(review): the command goes through the shell (needed so
        # pass_through_string is split into options), so paths containing
        # spaces are not quoted.
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        # communicate() drains stdout and waits for the child. The pipe is
        # closed and the process reaped for every job instead of being left
        # to the garbage collector.
        proc = Popen(cmdline,shell=True,stdout=PIPE)
        output = proc.communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if line.find("Cross") != -1:
                # last field minus its final character (presumably a '%')
                return float(line.split()[-1][0:-1])
242
+
243
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host over ssh.

    The remote command first cds into this process's working directory
    (captured at construction), so it assumes the host sees the same paths.
    """
    def __init__(self,name,job_queue,result_queue,host):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Cross-validate remotely with C = c, gamma = g; return the accuracy.

        Returns None if no output line contains "Cross".
        """
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
            (self.host,self.cwd,
             svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        # communicate() closes the pipe and reaps the ssh process per job
        proc = Popen(cmdline,shell=True,stdout=PIPE)
        output = proc.communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if line.find("Cross") != -1:
                # last field minus its final character (presumably a '%')
                return float(line.split()[-1][0:-1])
256
+
257
class TelnetWorker(Worker):
    """Worker that logs into a host over telnet and runs svm-train there.

    telnetlib exchanges bytes on Python 3. Every string sent or matched is
    therefore encoded and the captured output decoded; plain str arguments
    raised TypeError. NOTE: telnetlib was removed from the standard library
    in Python 3.13.
    """
    def __init__(self,name,job_queue,result_queue,host,username,password):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the local working directory, then process jobs."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        tn.read_until("login: ".encode())
        tn.write((self.username + "\n").encode())
        tn.read_until("Password: ".encode())
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        print('login ok', self.host)
        tn.write(("cd "+os.getcwd()+"\n").encode())
        Worker.run(self)
        tn.write("exit\n".encode())

    def run_one(self,c,g):
        """Run one cross-validation in the telnet session; return the accuracy.

        Returns None if no output line contains "Cross".
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        self.tn.write((cmdline+'\n').encode())
        (idx,matchm,output) = self.tn.expect(['Cross.*\n'.encode()])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if line.find("Cross") != -1:
                # last field minus its final character (presumably a '%')
                return float(line.split()[-1][0:-1])
286
+
287
def main():
    """Run the grid search end to end.

    Steps:
    1. Parse the command line (process_options).
    2. Queue every (log2c, log2g) pair from calculate_jobs().
    3. Start the configured telnet, ssh and local workers.
    4. Collect results batch by batch. Each result is written to
       out_filename and printed with the best so far, and the contour plot
       is redrawn after every batch.
    5. Save the plot to png_filename, tell the workers to stop, and print
       "best_c best_g best_rate".
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

    # From now on put() inserts at the FRONT of the queue, so a job handed
    # back by a failing Worker is retried before the remaining ones.
    # NOTE: this relies on Queue's private _put hook.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                         host,username,password).start()

    # fire ssh workers

    for host in ssh_workers:
        SSHWorker(host,job_queue,result_queue,host).start()

    # fire local workers

    for _ in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1,best_g1 = None,None

    # `with` guarantees the results file is flushed and closed (it used to
    # stay open for the lifetime of the process).
    with open(out_filename, 'w') as result_file:
        for line in jobs:
            for (c,g) in line:
                # Results arrive in completion order. Keep draining until
                # this particular job's result has been seen.
                while (c, g) not in done_jobs:
                    (worker,c1,g1,rate) = result_queue.get()
                    done_jobs[(c1,g1)] = rate
                    result_file.write('%s %s %s\n' %(c1,g1,rate))
                    result_file.flush()
                    # a tie at the same gamma is won by the smaller C
                    if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1<best_c1):
                        best_rate = rate
                        best_c1,best_g1=c1,g1
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                    print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                          (worker,c1,g1,rate, best_c, best_g, best_rate))
                db.append((c,g,done_jobs[(c,g)]))
            redraw(db,[best_c1, best_g1, best_rate])
    redraw(db,[best_c1, best_g1, best_rate],True)

    job_queue.put((WorkerStopToken,None))
    print("%s %s %s" % (best_c, best_g, best_rate))
359
# Run the grid search only when executed as a script, not on import.
if __name__ == '__main__':
    main()
@@ -0,0 +1,146 @@
1
+ #!/usr/bin/env python
2
+ from sys import argv, exit, stdout, stderr
3
+ from random import randint
4
+
5
method = 0              # selection method (-s): 0 stratified, 1 random
# Set by process_options(). A module-level `global` statement is a no-op and
# does not create these names, so bind explicit placeholders.
n = None                # number of instances to select
dataset_filename = None
subset_filename = ""    # "" means the subset is written to stdout
rest_filename = ""      # "" means the unselected instances are not written
10
+
11
def exit_with_help():
    """Print usage (naming this script via argv[0]) and exit with status 1."""
    help_text = """\
Usage: %s [options] dataset number [output1] [output2]

This script selects a subset of the given dataset.

options:
-s method : method of selection (default 0)
0 -- stratified selection (classification only)
1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen."""
    print(help_text % argv[0])
    exit(1)
26
+
27
def process_options():
    """Parse the command line and set the module-level selection settings.

    Expected form: ``[options] dataset number [output1] [output2]``. Sets
    method, n and dataset_filename, plus subset_filename and rest_filename
    when those are given. Calls exit_with_help() in these cases:
    - too few arguments;
    - an unknown option;
    - a missing, non-integer or out-of-range -s value;
    - a missing or non-integer dataset/number pair.
    Previously some of these raised IndexError/ValueError, and unknown
    options were silently ignored.
    """
    global method, n
    global dataset_filename, subset_filename, rest_filename

    argc = len(argv)
    if argc < 3:
        exit_with_help()

    i = 1
    while i < argc:
        # Options come first. The first argument not starting with "-" is
        # the dataset (startswith also copes with an empty argument).
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                exit_with_help()
            try:
                method = int(argv[i])
            except ValueError:
                exit_with_help()
            if method < 0 or method > 1:
                print("Unknown selection method %d" % (method))
                exit_with_help()
        else:
            print("Unknown option %s" % argv[i])
            exit_with_help()
        i = i + 1

    # dataset and number are mandatory after the options
    if i + 1 >= argc:
        exit_with_help()
    dataset_filename = argv[i]
    try:
        n = int(argv[i+1])
    except ValueError:
        exit_with_help()
    if i+2 < argc:
        subset_filename = argv[i+2]
    if i+3 < argc:
        rest_filename = argv[i+3]
53
+
54
def main():
    """Select n instances of dataset_filename and write them out.

    Settings come from process_options().
    - Method 0 (stratified): groups lines by label (the first field) and
      takes roughly n * class_size / total lines from each class, with at
      least one per class. A warning is printed when that minimum kicks in.
    - Method 1 (random): takes n lines uniformly at random.

    Selected lines go to subset_filename, or stdout when it is not given.
    Unselected lines go to rest_filename when it is given. Lines keep their
    original order.
    """
    class Label:
        # one per dataset line: its label, its line number, and a 0/1 selected flag
        def __init__(self, label, index, selected):
            self.label = label
            self.index = index
            self.selected = selected

    process_options()

    # get labels (first whitespace-separated field of every line)
    labels = []
    with open(dataset_filename, 'r') as f:
        for i, line in enumerate(f):
            labels.append(Label(float((line.split())[0]), i, 0))
    l = len(labels)

    # determine where to output
    split = 0
    file2 = None
    if subset_filename != "":
        file1 = open(subset_filename, 'w')
    else:
        file1 = stdout
    try:
        if rest_filename != "":
            split = 1
            file2 = open(rest_filename, 'w')

        # select the subset
        warning = 0
        if method == 0: # stratified
            # an empty dataset has nothing to select (and no labels[l-1])
            if l > 0:
                labels.sort(key = lambda x: x.label)

                # Sentinel whose label differs from every real one, so the
                # loop below also closes off the last class.
                label_end = labels[l-1].label + 1
                labels.append(Label(label_end, l, 0))

                begin = 0
                label = labels[begin].label
                for i in range(l+1):
                    new_label = labels[i].label
                    if new_label != label:
                        # labels[begin:i] is one class
                        nr_class = i - begin
                        # this class's share of n, by cumulative proportion
                        k = i*n//l - begin*n//l
                        # at least one instance per class
                        if k == 0:
                            k = 1
                            warning = warning + 1
                        # pick each member with probability k / (members left)
                        for j in range(nr_class):
                            if randint(0, nr_class-j-1) < k:
                                labels[begin+j].selected = 1
                                k = k - 1
                        begin = i
                        label = new_label
        elif method == 1: # random
            k = n
            # pick each line with probability k / (lines left)
            for i in range(l):
                if randint(0,l-i-1) < k:
                    labels[i].selected = 1
                    k = k - 1

        # output, in original line order
        if method == 0:
            labels.sort(key = lambda x: int(x.index))

        with open(dataset_filename, 'r') as f:
            for i, line in enumerate(f):
                if labels[i].selected == 1:
                    file1.write(line)
                elif split == 1:
                    file2.write(line)

        if warning > 0:
            stderr.write("""\
Warning:
1. You may have regression data. Please use -s 1.
2. Classification data unbalanced or too small. We select at least 1 per class.
The subset thus contains %d instances.
""" % (n+warning))
    finally:
        # cleanup: close only the files opened here, never sys.stdout
        if file1 is not stdout:
            file1.close()
        if file2 is not None:
            file2.close()
146
# Run the selection only when executed as a script, not on import.
if __name__ == '__main__':
    main()