eluka 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- data/.document +5 -0
- data/DOCUMENTATION_STANDARDS +39 -0
- data/Gemfile +13 -0
- data/Gemfile.lock +20 -0
- data/LICENSE.txt +20 -0
- data/README.rdoc +19 -0
- data/Rakefile +69 -0
- data/VERSION +1 -0
- data/examples/example.rb +59 -0
- data/ext/libsvm/COPYRIGHT +31 -0
- data/ext/libsvm/FAQ.html +1749 -0
- data/ext/libsvm/Makefile +25 -0
- data/ext/libsvm/Makefile.win +33 -0
- data/ext/libsvm/README +733 -0
- data/ext/libsvm/extconf.rb +1 -0
- data/ext/libsvm/heart_scale +270 -0
- data/ext/libsvm/java/Makefile +25 -0
- data/ext/libsvm/java/libsvm.jar +0 -0
- data/ext/libsvm/java/libsvm/svm.java +2776 -0
- data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
- data/ext/libsvm/java/libsvm/svm_model.java +21 -0
- data/ext/libsvm/java/libsvm/svm_node.java +6 -0
- data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
- data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
- data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
- data/ext/libsvm/java/svm_predict.java +163 -0
- data/ext/libsvm/java/svm_scale.java +350 -0
- data/ext/libsvm/java/svm_toy.java +471 -0
- data/ext/libsvm/java/svm_train.java +318 -0
- data/ext/libsvm/java/test_applet.html +1 -0
- data/ext/libsvm/python/Makefile +4 -0
- data/ext/libsvm/python/README +331 -0
- data/ext/libsvm/python/svm.py +259 -0
- data/ext/libsvm/python/svmutil.py +242 -0
- data/ext/libsvm/svm-predict.c +226 -0
- data/ext/libsvm/svm-scale.c +353 -0
- data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
- data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
- data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
- data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
- data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
- data/ext/libsvm/svm-toy/gtk/main.c +23 -0
- data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
- data/ext/libsvm/svm-toy/qt/Makefile +17 -0
- data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
- data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
- data/ext/libsvm/svm-train.c +376 -0
- data/ext/libsvm/svm.cpp +3060 -0
- data/ext/libsvm/svm.def +19 -0
- data/ext/libsvm/svm.h +105 -0
- data/ext/libsvm/svm.o +0 -0
- data/ext/libsvm/tools/README +149 -0
- data/ext/libsvm/tools/checkdata.py +108 -0
- data/ext/libsvm/tools/easy.py +79 -0
- data/ext/libsvm/tools/grid.py +359 -0
- data/ext/libsvm/tools/subset.py +146 -0
- data/ext/libsvm/windows/libsvm.dll +0 -0
- data/ext/libsvm/windows/svm-predict.exe +0 -0
- data/ext/libsvm/windows/svm-scale.exe +0 -0
- data/ext/libsvm/windows/svm-toy.exe +0 -0
- data/ext/libsvm/windows/svm-train.exe +0 -0
- data/lib/eluka.rb +10 -0
- data/lib/eluka/bijection.rb +23 -0
- data/lib/eluka/data_point.rb +36 -0
- data/lib/eluka/document.rb +47 -0
- data/lib/eluka/feature_vector.rb +86 -0
- data/lib/eluka/features.rb +31 -0
- data/lib/eluka/model.rb +129 -0
- data/lib/fselect.rb +321 -0
- data/lib/grid.rb +25 -0
- data/test/helper.rb +18 -0
- data/test/test_eluka.rb +7 -0
- metadata +214 -0
@@ -0,0 +1,359 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
|
4
|
+
|
5
|
+
import os, sys, traceback
|
6
|
+
import getpass
|
7
|
+
from threading import Thread
|
8
|
+
from subprocess import *
|
9
|
+
|
10
|
+
# The thread-safe queue module is called "Queue" on Python 2 and "queue" on
# Python 3; expose it under the Python 2 name either way.
if sys.hexversion >= 0x03000000:
    import queue as Queue
else:
    import Queue

# Default locations of the svm-train and gnuplot executables; both can be
# overridden on the command line with -svmtrain / -gnuplot.
is_win32 = (sys.platform == 'win32')
if is_win32:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Default search grid, in log2 units, and the fold count handed to svm-train -v.
fold = 5
c_begin, c_end, c_step = -5, 15, 2
g_begin, g_end, g_step = 3, -15, -2
# dataset_pathname, dataset_title, pass_through_string, out_filename and
# png_filename are module globals that process_options() assigns.

# Experimental: extra hosts to run jobs on, plus the number of local workers.
telnet_workers = []
ssh_workers = []
nr_local_worker = 1
|
40
|
+
|
41
|
+
# process command line options, set global parameters
|
42
|
+
def process_options(argv=sys.argv):
    """Parse the grid.py command line into the module-level settings.

    ``argv`` is the full argument vector (``argv[0]`` is the program name).
    Its last element is always taken as the dataset path. Recognised options
    set fold, the log2(C)/log2(gamma) ranges, the executable paths and the
    output file names. Every other argument is passed through to svm-train
    unchanged. Finally, a gnuplot process is started and its stdin is stored
    in the global ``gnuplot``.

    Prints the usage text and exits with status 1 in three cases: argv is
    too short, the renamed -c/-g options are used, or an option is missing
    its value.

    Raises:
        IOError: if the svm-train executable, the gnuplot executable or the
            dataset does not exist.
        ValueError: if a -log2c/-log2g value is not three comma-separated
            numbers.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    # Options that take the following argument as their value.
    value_options = ('-log2c', '-log2g', '-v', '-svmtrain', '-gnuplot',
                     '-out', '-png')

    i = 1
    while i < len(argv) - 1:
        option = argv[i]
        if option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        if option not in value_options:
            pass_through_options.append(option)
            i = i + 1
            continue

        i = i + 1
        if i >= len(argv) - 1:
            # Only the dataset is left, so the option has no value; without
            # this check the dataset path would be silently used as the value.
            print(usage)
            sys.exit(1)
        value = argv[i]
        if option == '-log2c':
            (c_begin, c_end, c_step) = map(float, value.split(","))
        elif option == '-log2g':
            (g_begin, g_end, g_step) = map(float, value.split(","))
        elif option == '-v':
            fold = value
        elif option == '-svmtrain':
            svmtrain_exe = value
        elif option == '-gnuplot':
            gnuplot_exe = value
        elif option == '-out':
            out_filename = value
        elif option == '-png':
            png_filename = value
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit checks rather than assert: asserts vanish under "python -O".
    # A missing svm-train would then make every worker quit while main()
    # waits forever for results.
    for path, what in ((svmtrain_exe, "svm-train executable"),
                       (gnuplot_exe, "gnuplot executable"),
                       (dataset_pathname, "dataset")):
        if not os.path.exists(path):
            raise IOError("%s not found" % what)

    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
|
101
|
+
|
102
|
+
|
103
|
+
def range_f(begin, end, step):
    """Return [begin, begin+step, begin+2*step, ...] up to and including end.

    Works like range() but accepts non-integer bounds and steps. The end
    point is included when the steps land on it. A negative step counts
    downward. Values are produced by repeated addition, so with fractional
    steps rounding error can accumulate.

    Raises:
        ValueError: if step is zero or NaN, which would otherwise never
            terminate.
    """
    # "not (step > 0 or step < 0)" is true for 0 and also for NaN.
    if not (step > 0 or step < 0):
        raise ValueError("range_f() step must be non-zero, got %r" % (step,))
    seq = []
    while True:
        if step > 0 and begin > end:
            break
        if step < 0 and begin < end:
            break
        seq.append(begin)
        begin = begin + step
    return seq
|
112
|
+
|
113
|
+
def permute_sequence(seq):
    """Reorder seq so that early elements spread evenly over its span.

    The middle element comes first. It is followed by the recursively
    permuted lower and upper halves, interleaved one element at a time
    (lower first). Sequences of length 0 or 1 are returned as the same
    object.
    """
    size = len(seq)
    if size <= 1:
        return seq

    centre = size // 2
    lower = permute_sequence(seq[:centre])
    upper = permute_sequence(seq[centre + 1:])

    merged = [seq[centre]]
    for pos in range(max(len(lower), len(upper))):
        if pos < len(lower):
            merged.append(lower[pos])
        if pos < len(upper):
            merged.append(upper[pos])
    return merged
|
127
|
+
|
128
|
+
def redraw(db, best_param, tofile=False):
    """Send a contour plot of the results gathered so far to gnuplot.

    db is a list of (log2 C, log2 gamma, rate) tuples. It is sorted in place
    by C ascending, then gamma descending. best_param is
    [best log2 C, best log2 gamma, best rate] and is shown as labels. When
    tofile is true the plot goes to png_filename; otherwise it goes to the
    "windows" terminal on Win32 and to "x11" elsewhere. Commands are written
    as bytes to the module-level ``gnuplot`` stream. Nothing is sent when db
    is empty.
    """
    if not db:
        return
    # Contours start 3 units below the best rate and are 0.5 apart.
    contour_start = round(max(row[2] for row in db)) - 3
    contour_step = 0.5

    best_log2c, best_log2g, best_rate = best_param

    if tofile:
        # Backslashes must be doubled inside a gnuplot double-quoted string.
        escaped_png = png_filename.replace('\\', '\\\\')
        commands = ['set term png transparent small\n',
                    'set output "%s"\n' % escaped_png]
    elif is_win32:
        commands = ['set term windows\n']
    else:
        commands = ['set term x11\n']

    commands.extend([
        'set xlabel "log2(C)"\n',
        'set ylabel "log2(gamma)"\n',
        'set xrange [%s:%s]\n' % (c_begin, c_end),
        'set yrange [%s:%s]\n' % (g_begin, g_end),
        'set contour\n',
        'set cntrparam levels incremental %s,%s,100\n'
        % (contour_start, contour_step),
        'unset surface\n',
        'unset ztics\n',
        'set view 0,0\n',
        'set title "%s"\n' % dataset_title,
        'unset label\n',
        'set label "Best log2(C) = %s log2(gamma) = %s accuracy = %s%%"'
        ' at screen 0.5,0.85 center\n' % (best_log2c, best_log2g, best_rate),
        'set label "C = %s gamma = %s" at screen 0.5,0.8 center\n'
        % (2**best_log2c, 2**best_log2g),
        'splot "-" with lines\n',
    ])
    for command in commands:
        gnuplot.write(command.encode())

    db.sort(key=lambda row: (row[0], -row[1]))

    # Inline data: a blank line separates consecutive runs of equal C.
    current_c = db[0][0]
    for row in db:
        if row[0] != current_c:
            gnuplot.write("\n".encode())
            current_c = row[0]
        gnuplot.write(("%s %s %s\n" % row).encode())
    gnuplot.write("e\n".encode())
    # Force gnuplot back to its prompt if the "set term" command failed.
    gnuplot.write("\n".encode())
    gnuplot.flush()
|
176
|
+
|
177
|
+
|
178
|
+
def calculate_jobs():
    """Split the (log2 C, log2 gamma) grid into batches of jobs.

    Both axes are built with range_f from the module-level begin/end/step
    settings and reordered by permute_sequence. Each iteration adds one new
    value to whichever axis currently has the smaller covered fraction. It
    then emits a batch pairing that value with every value already taken on
    the other axis. Returns a list of batches; each batch is a list of
    (log2 C, log2 gamma) tuples. The first batch is always empty.

    Raises:
        ValueError: if either range is empty (begin, end and step disagree);
            the fraction computation below would otherwise divide by zero.
    """
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    if not c_seq or not g_seq:
        raise ValueError("empty -log2c or -log2g range; check begin,end,step")

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # increase C resolution: next C against every gamma so far
            jobs.append([(c_seq[i], g) for g in g_seq[:j]])
            i = i + 1
        else:
            # increase gamma resolution: next gamma against every C so far
            jobs.append([(c, g_seq[j]) for c in c_seq[:i]])
            j = j + 1
    return jobs
|
203
|
+
|
204
|
+
class WorkerStopToken:
    """Sentinel that tells workers to stop.

    main() puts the class object itself on the job queue as (WorkerStopToken,
    None), and Worker.run() tests for it with ``is``.
    """
|
206
|
+
|
207
|
+
class Worker(Thread):
    """Thread that takes (log2 C, log2 gamma) jobs and reports their rates.

    Subclasses implement run_one(c, g), which receives C and gamma as plain
    values (2**exponent) and returns the cross-validation rate, or None on
    failure. Results go to result_queue as (name, log2 C, log2 gamma, rate).
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run_one(self, c, g):
        """Return the cross-validation rate for C=c, gamma=g (subclass hook)."""
        raise NotImplementedError

    def run(self):
        """Process jobs until the stop token is seen or a job fails."""
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so the remaining workers see it too.
                self.job_queue.put((cexp, gexp))
                break
            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # We failed: hand the job back for another worker and quit.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))
|
233
|
+
|
234
|
+
class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""

    def run_one(self, c, g):
        """Run one cross-validation with C=c, gamma=g; return the rate or None.

        The command line is built from the module-level svmtrain_exe, fold,
        pass_through_string and dataset_pathname, and is run through the
        shell. Returns None when svm-train prints no line containing "Cross".
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() reads all output AND waits for the child, so no zombie
        # process or open pipe is left behind for each grid point.
        output = Popen(cmdline, shell=True, stdout=PIPE).communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if line.find("Cross") != -1:
                # e.g. "Cross Validation Accuracy = 96.2963%": the last token
                # with a trailing '%' removed. rstrip keeps values that have
                # no '%' intact instead of chopping their last digit.
                return float(line.split()[-1].rstrip('%'))
        return None
|
242
|
+
|
243
|
+
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ssh.

    The remote command first cds to this process's working directory, so
    the same path must presumably exist on the remote host (e.g. a shared
    filesystem) -- TODO confirm for your setup.
    """

    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Run one remote cross-validation with C=c, gamma=g; return the rate or None."""
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
            (self.host, self.cwd,
             svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() drains stdout and reaps the ssh child process.
        output = Popen(cmdline, shell=True, stdout=PIPE).communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if line.find("Cross") != -1:
                # Last token, with a trailing '%' (if any) removed.
                return float(line.split()[-1].rstrip('%'))
        return None
|
256
|
+
|
257
|
+
class TelnetWorker(Worker):
    """Experimental worker that logs into a host over telnet and runs svm-train there.

    NOTE(review): telnetlib was removed from the standard library in
    Python 3.13, so this worker cannot start on newer interpreters.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the current directory, process jobs, then log out."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib exchanges bytes on Python 3; .encode() also yields a
        # plain str on Python 2, so this works on both.
        tn.read_until("login: ".encode())
        tn.write((self.username + "\n").encode())
        tn.read_until("Password: ".encode())
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        print('login ok %s' % self.host)
        tn.write(("cd " + os.getcwd() + "\n").encode())
        Worker.run(self)
        tn.write("exit\n".encode())
        tn.close()

    def run_one(self, c, g):
        """Run one cross-validation over the telnet session; return the rate or None."""
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        self.tn.write((cmdline + '\n').encode())
        # Block until svm-train prints its "Cross ..." result line.
        (idx, match, output) = self.tn.expect(['Cross.*\n'.encode()])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if line.find("Cross") != -1:
                return float(line.split()[-1].rstrip('%'))
        return None
|
286
|
+
|
287
|
+
def main():
    """Run the grid search end to end.

    Steps:
    1. Parse the options.
    2. Queue every (log2 C, log2 gamma) job.
    3. Start the telnet, ssh and local workers.
    4. Append every result to out_filename as "log2c log2g rate".
    5. Redraw the contour plot after each batch, both on screen and into
       png_filename.
    6. Print "best_c best_g best_rate".
    """
    # set parameters
    process_options()

    # put jobs in queue
    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)
    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # From now on put() inserts at the head of the queue. A job handed back
    # by a failed worker, and the stop token, are therefore taken next.
    # NOTE: relies on Queue.Queue internals (_put and the .queue deque).
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers
    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    # fire ssh workers
    for host in ssh_workers:
        SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers
    for _ in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results
    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    # Defined up front so the final report works even if no result arrived.
    best_c, best_g = None, None

    result_file = open(out_filename, 'w')
    try:
        for line in jobs:
            for (c, g) in line:
                # Results arrive in any order. Keep collecting until this job
                # is done. NOTE: if every worker has quit, get() blocks forever.
                while (c, g) not in done_jobs:
                    (worker, c1, g1, rate) = result_queue.get()
                    done_jobs[(c1, g1)] = rate
                    result_file.write('%s %s %s\n' % (c1, g1, rate))
                    result_file.flush()
                    # On a tie at the same gamma, prefer the smaller C.
                    if (rate > best_rate) or (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                        best_rate = rate
                        best_c1, best_g1 = c1, g1
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                    print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" %
                          (worker, c1, g1, rate, best_c, best_g, best_rate))
                db.append((c, g, done_jobs[(c, g)]))
            redraw(db, [best_c1, best_g1, best_rate])
            redraw(db, [best_c1, best_g1, best_rate], True)
    finally:
        result_file.close()

    job_queue.put((WorkerStopToken, None))
    print("%s %s %s" % (best_c, best_g, best_rate))


if __name__ == '__main__':
    main()
|
@@ -0,0 +1,146 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
from sys import argv, exit, stdout, stderr
|
3
|
+
from random import randint
|
4
|
+
|
5
|
+
# Selection method chosen with -s: 0 = stratified, 1 = random.
method = 0
# n (the subset size) and dataset_filename are assigned by process_options().
# Output file names; "" means the argument was not given.
subset_filename = ""
rest_filename = ""
|
10
|
+
|
11
|
+
def exit_with_help():
    """Print the command-line usage to stdout and exit with status 1."""
    usage_text = """\
Usage: %s [options] dataset number [output1] [output2]

This script selects a subset of the given dataset.

options:
-s method : method of selection (default 0)
0 -- stratified selection (classification only)
1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen."""
    print(usage_text % argv[0])
    exit(1)
|
26
|
+
|
27
|
+
def process_options():
    """Parse sys.argv into the module-level settings.

    Sets method (via -s), dataset_filename and n (the two required
    positional arguments), and optionally subset_filename and rest_filename.
    Options end at the first argument that does not start with "-". Unknown
    options are skipped. Calls exit_with_help() on any malformed command
    line instead of crashing with IndexError/ValueError.
    """
    global method, n
    global dataset_filename, subset_filename, rest_filename

    argc = len(argv)
    if argc < 3:
        exit_with_help()

    i = 1
    while i < argc:
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                exit_with_help()
            try:
                method = int(argv[i])
            except ValueError:
                exit_with_help()
            if method < 0 or method > 1:
                print("Unknown selection method %d" % (method))
                exit_with_help()
        i = i + 1

    # Both the dataset and the subset size must follow the options.
    if i + 1 >= argc:
        exit_with_help()
    dataset_filename = argv[i]
    try:
        n = int(argv[i + 1])
    except ValueError:
        exit_with_help()
    if i + 2 < argc:
        subset_filename = argv[i + 2]
    if i + 3 < argc:
        rest_filename = argv[i + 3]
|
53
|
+
|
54
|
+
def main():
    """Select n instances of the dataset and write them out.

    Settings come from the module globals filled by process_options():
    method, n, dataset_filename, subset_filename and rest_filename.
    - Method 0 picks from each class in proportion to the class size, and at
      least one per class.
    - Method 1 picks n instances uniformly at random.

    Selected lines go to subset_filename (stdout if not given). With a
    rest_filename, the other lines are written there. Output keeps the
    original line order. Blank lines are ignored.
    """
    class Label:
        def __init__(self, label, index, selected):
            self.label = label        # numeric value of the line's first token
            self.index = index        # position among the non-blank lines
            self.selected = selected  # 1 once chosen for the subset

    process_options()

    # get labels; blank lines carry no instance and are skipped here and in
    # the output pass below, so indices stay aligned
    labels = []
    with open(dataset_filename, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            labels.append(Label(float(fields[0]), len(labels), 0))
    total = len(labels)

    # determine where to output
    if subset_filename != "":
        file1 = open(subset_filename, 'w')
    else:
        file1 = stdout
    split = 0
    if rest_filename != "":
        split = 1
        file2 = open(rest_filename, 'w')

    # select the subset
    warning = 0
    if method == 0:  # stratified
        if total > 0:
            labels.sort(key=lambda x: x.label)
            # Sentinel whose label differs from every real one, so the loop
            # below also closes the last class.
            label_end = labels[total - 1].label + 1
            labels.append(Label(label_end, total, 0))

            begin = 0
            label = labels[begin].label
            for i in range(total + 1):
                new_label = labels[i].label
                if new_label != label:
                    nr_class = i - begin
                    # this class's proportional share of n
                    k = i * n // total - begin * n // total
                    # at least one instance per class
                    if k == 0:
                        k = 1
                        warning = warning + 1
                    # Selection sampling: choose k of the nr_class instances
                    # uniformly at random (all of them if k >= nr_class).
                    for j in range(nr_class):
                        if randint(0, nr_class - j - 1) < k:
                            labels[begin + j].selected = 1
                            k = k - 1
                    begin = i
                    label = new_label
    elif method == 1:  # random
        k = n
        for i in range(total):
            if randint(0, total - i - 1) < k:
                labels[i].selected = 1
                k = k - 1

    # output, in the original file order
    if method == 0:
        labels.sort(key=lambda x: x.index)

    i = 0
    with open(dataset_filename, 'r') as f:
        for line in f:
            if not line.split():
                continue
            if labels[i].selected == 1:
                file1.write(line)
            elif split == 1:
                file2.write(line)
            i = i + 1

    if warning > 0:
        stderr.write("""\
Warning:
1. You may have regression data. Please use -s 1.
2. Classification data unbalanced or too small. We select at least 1 per class.
The subset thus contains %d instances.
""" % (n + warning))

    # cleanup; never close the process's own stdout
    if file1 is not stdout:
        file1.close()
    if split == 1:
        file2.close()


if __name__ == '__main__':
    main()
|