eluka 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/.document +5 -0
- data/DOCUMENTATION_STANDARDS +39 -0
- data/Gemfile +13 -0
- data/Gemfile.lock +20 -0
- data/LICENSE.txt +20 -0
- data/README.rdoc +19 -0
- data/Rakefile +69 -0
- data/VERSION +1 -0
- data/examples/example.rb +59 -0
- data/ext/libsvm/COPYRIGHT +31 -0
- data/ext/libsvm/FAQ.html +1749 -0
- data/ext/libsvm/Makefile +25 -0
- data/ext/libsvm/Makefile.win +33 -0
- data/ext/libsvm/README +733 -0
- data/ext/libsvm/extconf.rb +1 -0
- data/ext/libsvm/heart_scale +270 -0
- data/ext/libsvm/java/Makefile +25 -0
- data/ext/libsvm/java/libsvm.jar +0 -0
- data/ext/libsvm/java/libsvm/svm.java +2776 -0
- data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
- data/ext/libsvm/java/libsvm/svm_model.java +21 -0
- data/ext/libsvm/java/libsvm/svm_node.java +6 -0
- data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
- data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
- data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
- data/ext/libsvm/java/svm_predict.java +163 -0
- data/ext/libsvm/java/svm_scale.java +350 -0
- data/ext/libsvm/java/svm_toy.java +471 -0
- data/ext/libsvm/java/svm_train.java +318 -0
- data/ext/libsvm/java/test_applet.html +1 -0
- data/ext/libsvm/python/Makefile +4 -0
- data/ext/libsvm/python/README +331 -0
- data/ext/libsvm/python/svm.py +259 -0
- data/ext/libsvm/python/svmutil.py +242 -0
- data/ext/libsvm/svm-predict.c +226 -0
- data/ext/libsvm/svm-scale.c +353 -0
- data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
- data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
- data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
- data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
- data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
- data/ext/libsvm/svm-toy/gtk/main.c +23 -0
- data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
- data/ext/libsvm/svm-toy/qt/Makefile +17 -0
- data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
- data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
- data/ext/libsvm/svm-train.c +376 -0
- data/ext/libsvm/svm.cpp +3060 -0
- data/ext/libsvm/svm.def +19 -0
- data/ext/libsvm/svm.h +105 -0
- data/ext/libsvm/svm.o +0 -0
- data/ext/libsvm/tools/README +149 -0
- data/ext/libsvm/tools/checkdata.py +108 -0
- data/ext/libsvm/tools/easy.py +79 -0
- data/ext/libsvm/tools/grid.py +359 -0
- data/ext/libsvm/tools/subset.py +146 -0
- data/ext/libsvm/windows/libsvm.dll +0 -0
- data/ext/libsvm/windows/svm-predict.exe +0 -0
- data/ext/libsvm/windows/svm-scale.exe +0 -0
- data/ext/libsvm/windows/svm-toy.exe +0 -0
- data/ext/libsvm/windows/svm-train.exe +0 -0
- data/lib/eluka.rb +10 -0
- data/lib/eluka/bijection.rb +23 -0
- data/lib/eluka/data_point.rb +36 -0
- data/lib/eluka/document.rb +47 -0
- data/lib/eluka/feature_vector.rb +86 -0
- data/lib/eluka/features.rb +31 -0
- data/lib/eluka/model.rb +129 -0
- data/lib/fselect.rb +321 -0
- data/lib/grid.rb +25 -0
- data/test/helper.rb +18 -0
- data/test/test_eluka.rb +7 -0
- metadata +214 -0
@@ -0,0 +1,359 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
|
4
|
+
|
5
|
+
import os, sys, traceback
|
6
|
+
import getpass
|
7
|
+
from threading import Thread
|
8
|
+
from subprocess import *
|
9
|
+
|
10
|
+
if(sys.hexversion < 0x03000000):
|
11
|
+
import Queue
|
12
|
+
else:
|
13
|
+
import queue as Queue
|
14
|
+
|
15
|
+
|
16
|
+
# svmtrain and gnuplot executable
|
17
|
+
|
18
|
+
# Platform switch: selects the default executable paths below.
is_win32 = sys.platform == 'win32'
if is_win32:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Global parameters and their default values.
fold = 5
c_begin = -5
c_end = 15
c_step = 2
g_begin = 3
g_end = -15
g_step = -2
# dataset_pathname, dataset_title, pass_through_string, out_filename and
# png_filename are module globals as well; they are not given a value here.

# experimental
telnet_workers = []
ssh_workers = []
nr_local_worker = 1
|
40
|
+
|
41
|
+
# process command line options, set global parameters
|
42
|
+
def process_options(argv=sys.argv):
    """Parse the command line and set the module-level search parameters.

    The last element of ``argv`` is always taken as the dataset path.
    Recognised options (each consumes the following argument):

    * ``-log2c begin,end,step`` / ``-log2g begin,end,step`` -- parsed as
      floats into ``c_*`` / ``g_*``.
    * ``-v fold`` -- stored as given (a string).
    * ``-svmtrain``, ``-gnuplot``, ``-out``, ``-png`` -- path overrides.

    ``-c`` and ``-g`` are rejected with the usage text; every other argument
    is collected into ``pass_through_string``.

    Side effects: prints usage and exits with status 1 when no arguments are
    given; starts the gnuplot executable and stores its stdin pipe in the
    global ``gnuplot``.

    Raises:
        AssertionError: if the svm-train executable, the gnuplot executable
            or the dataset does not exist.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    def require_path(path, message):
        # Explicit raise instead of an `assert` statement: asserts are removed
        # under `python -O`, which would silently skip this validation.  The
        # exception type and message are the ones the asserts produced.
        if not os.path.exists(path):
            raise AssertionError(message)

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    # Stop one short of the end: the final argument is the dataset.
    i = 1
    while i < len(argv) - 1:
        if argv[i] == "-log2c":
            i = i + 1
            (c_begin, c_end, c_step) = map(float, argv[i].split(","))
        elif argv[i] == "-log2g":
            i = i + 1
            (g_begin, g_end, g_step) = map(float, argv[i].split(","))
        elif argv[i] == "-v":
            i = i + 1
            fold = argv[i]
        elif argv[i] in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif argv[i] == '-svmtrain':
            i = i + 1
            svmtrain_exe = argv[i]
        elif argv[i] == '-gnuplot':
            i = i + 1
            gnuplot_exe = argv[i]
        elif argv[i] == '-out':
            i = i + 1
            out_filename = argv[i]
        elif argv[i] == '-png':
            i = i + 1
            png_filename = argv[i]
        else:
            pass_through_options.append(argv[i])
        i = i + 1

    pass_through_string = " ".join(pass_through_options)
    require_path(svmtrain_exe, "svm-train executable not found")
    require_path(gnuplot_exe, "gnuplot executable not found")
    require_path(dataset_pathname, "dataset not found")
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
|
101
|
+
|
102
|
+
|
103
|
+
def range_f(begin, end, step):
    """Return the arithmetic sequence from ``begin`` to ``end`` inclusive.

    Like ``range``, but works on non-integer values too, and ``end`` itself
    is included when the sequence lands on it.  A positive ``step`` counts
    up while the value is <= ``end``; a negative ``step`` counts down while
    the value is >= ``end``.

    Raises:
        ValueError: if ``step`` is zero (the sequence would never terminate).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")

    seq = []
    value = begin
    while (value <= end) if step > 0 else (value >= end):
        seq.append(value)
        # Repeated addition (not begin + k*step) keeps the exact values the
        # sequence has always produced for float steps.
        value = value + step
    return seq
|
112
|
+
|
113
|
+
def permute_sequence(seq):
    """Reorder ``seq`` so that its middle element comes first.

    The halves on either side of the middle are reordered recursively in the
    same way and then interleaved (one element from the lower half, one from
    the upper half, ...).  A sequence of length 0 or 1 is returned as is.
    """
    size = len(seq)
    if size <= 1:
        return seq

    middle = size // 2
    lower = permute_sequence(seq[:middle])
    upper = permute_sequence(seq[middle + 1:])

    ordered = [seq[middle]]
    for pos in range(max(len(lower), len(upper))):
        if pos < len(lower):
            ordered.append(lower[pos])
        if pos < len(upper):
            ordered.append(upper[pos])
    return ordered
|
127
|
+
|
128
|
+
def redraw(db,best_param,tofile=False):
    """Send a contour plot of the results gathered so far to gnuplot.

    db         -- list of (log2c, log2g, rate) tuples; sorted in place.
    best_param -- (best_log2c, best_log2g, best_rate), shown as plot labels.
    tofile     -- when true, render to ``png_filename`` instead of a window.

    Does nothing when ``db`` is empty.  All commands are written to the
    global ``gnuplot`` pipe.
    """
    if not db:
        return

    def send(text):
        # The pipe takes bytes, so every command is encoded before writing.
        gnuplot.write(text.encode())

    begin_level = round(max(row[2] for row in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # Output terminal: PNG file, or an on-screen window for the platform.
    # (A postscript terminal writing "<dataset_title>.ps" is a possible
    # alternative to the PNG output.)
    if tofile:
        send("set term png transparent small\n")
        send("set output \"%s\"\n" % png_filename.replace('\\','\\\\'))
    elif is_win32:
        send("set term windows\n")
    else:
        send("set term x11\n")

    send("set xlabel \"log2(C)\"\n")
    send("set ylabel \"log2(gamma)\"\n")
    send("set xrange [%s:%s]\n" % (c_begin, c_end))
    send("set yrange [%s:%s]\n" % (g_begin, g_end))
    send("set contour\n")
    send("set cntrparam levels incremental %s,%s,100\n" % (begin_level, step_size))
    send("unset surface\n")
    send("unset ztics\n")
    send("set view 0,0\n")
    send("set title \"%s\"\n" % dataset_title)
    send("unset label\n")

    best_label = ("set label \"Best log2(C) = %s log2(gamma) = %s accuracy = %s%%\""
                  " at screen 0.5,0.85 center\n")
    send(best_label % (best_log2c, best_log2g, best_rate))
    raw_label = "set label \"C = %s gamma = %s\"" " at screen 0.5,0.8 center\n"
    send(raw_label % (2**best_log2c, 2**best_log2g))
    send("splot \"-\" with lines\n")

    # Inline data: rows ordered by C (then descending gamma), with a blank
    # line written whenever the C value changes.
    db.sort(key = lambda row: (row[0], -row[1]))
    previous_c = db[0][0]
    for row in db:
        if row[0] != previous_c:
            send("\n")
            previous_c = row[0]
        send("%s %s %s\n" % row)
    send("e\n")
    send("\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()
|
176
|
+
|
177
|
+
|
178
|
+
def calculate_jobs():
    """Build the order in which the (log2 C, log2 gamma) grid is evaluated.

    Both axes are generated with range_f() from the global begin/end/step
    settings and reordered with permute_sequence().  The result is a list of
    "lines", each a list of (c, g) tuples: every step adds either the next C
    value paired with all gamma values introduced so far, or the next gamma
    value paired with all C values introduced so far, whichever axis is
    proportionally behind.  Lines may be empty.
    """
    c_seq = permute_sequence(range_f(c_begin,c_end,c_step))
    g_seq = permute_sequence(range_f(g_begin,g_end,g_step))
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))

    jobs = []
    c_used = 0
    g_used = 0
    while c_used < nr_c or g_used < nr_g:
        if c_used / nr_c < g_used / nr_g:
            # increase C resolution
            new_c = c_seq[c_used]
            jobs.append([(new_c, g) for g in g_seq[:g_used]])
            c_used += 1
        else:
            # increase g resolution
            new_g = g_seq[g_used]
            jobs.append([(c, new_g) for c in c_seq[:c_used]])
            g_used += 1
    return jobs
|
203
|
+
|
204
|
+
class WorkerStopToken:
    """Marker put on the job queue to notify the workers to stop."""
|
206
|
+
|
207
|
+
class Worker(Thread):
    """Thread that takes (cexp, gexp) jobs from a queue and reports rates.

    Subclasses supply ``run_one(c, g)``, which must return the rate for the
    given parameter values (or None on failure).
    """

    def __init__(self,name,job_queue,result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token is seen or one job fails."""
        while True:
            (cexp, gexp) = self.job_queue.get()

            if cexp is WorkerStopToken:
                # Put the token back so the remaining workers see it too.
                self.job_queue.put((cexp, gexp))
                break

            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except:
                # we failed, let others do that and we just quit
                traceback.print_exception(*sys.exc_info())
                self.job_queue.put((cexp, gexp))
                print('worker %s quit.' % self.name)
                break

            self.result_queue.put((self.name, cexp, gexp, rate))
|
233
|
+
|
234
|
+
class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""

    def run_one(self,c,g):
        """Cross-validate with parameters ``c`` and ``g``; return the rate.

        Returns the number (without its final character) ending the first
        output line that contains "Cross", as a float, or None when no such
        line is printed.
        """
        # NOTE(review): the command is assembled by string interpolation and
        # run through the shell, so paths containing spaces or shell
        # metacharacters are not handled safely.
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        child = Popen(cmdline,shell=True,stdout=PIPE)
        # communicate() reads all output, waits for the child to exit and
        # closes the pipe, so neither the process nor the descriptor leaks.
        output = child.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
|
242
|
+
|
243
|
+
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh``."""

    def __init__(self,name,job_queue,result_queue,host):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        # The remote command first changes into this same directory.
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Cross-validate with ``c`` and ``g`` on the host; return the rate.

        Returns the number (without its final character) ending the first
        output line that contains "Cross", as a float, or None when no such
        line is printed.
        """
        # NOTE(review): the command is assembled by string interpolation and
        # run through the shell, so paths containing spaces or shell
        # metacharacters are not handled safely.
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
            (self.host,self.cwd,
             svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        child = Popen(cmdline,shell=True,stdout=PIPE)
        # communicate() reads all output, waits for ssh to exit and closes
        # the pipe, so neither the process nor the descriptor leaks.
        output = child.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
|
256
|
+
|
257
|
+
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session."""

    def __init__(self,name,job_queue,result_queue,host,username,password):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, change to the current directory, then process jobs."""
        # NOTE(review): telnetlib is deprecated in recent Python releases and
        # is no longer shipped from Python 3.13 -- confirm target versions.
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib reads and writes bytes on Python 3, so every string is
        # encoded before it is sent or matched.
        tn.read_until(b"login: ")
        tn.write((self.username + "\n").encode())
        tn.read_until(b"Password: ")
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        #
        print('login ok', self.host)
        tn.write(("cd "+os.getcwd()+"\n").encode())
        Worker.run(self)
        tn.write(b"exit\n")

    def run_one(self,c,g):
        """Cross-validate with ``c`` and ``g`` remotely; return the rate.

        Sends the svm-train command line over the session and waits for
        output matching ``Cross.*\\n``.  Returns the number (without its
        final character) ending the first line that contains "Cross", as a
        float, or None when there is no such line.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        self.tn.write((cmdline+'\n').encode())
        (idx,matchm,output) = self.tn.expect([b'Cross.*\n'])
        for line in output.split(b'\n'):
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
|
286
|
+
|
287
|
+
def main():
    """Run the grid search and report the best (C, gamma, rate).

    Parses the options, queues every grid point, starts the configured
    telnet, ssh and local workers, then gathers results in job order.  Each
    result is appended to ``out_filename`` and printed; the plot is redrawn
    after every finished grid point and finally written to the PNG file.
    The last line printed is "<best C> <best gamma> <best rate>".
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

    # From now on put() inserts at the FRONT of the queue (this relies on
    # Queue internals), so re-queued jobs and the stop token are picked up
    # before the remaining jobs.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                         host,username,password).start()

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host,job_queue,result_queue,host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1,best_g1 = None,None
    # Defined up front so the final report does not fail with a NameError
    # when no result was gathered.
    best_c,best_g = None,None

    try:
        # `with` guarantees the result file is closed, even on error.
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c,g) in line:
                    # Results arrive in any order; keep reading until the
                    # grid point we are waiting for has been reported.
                    while (c, g) not in done_jobs:
                        (worker,c1,g1,rate) = result_queue.get()
                        done_jobs[(c1,g1)] = rate
                        result_file.write('%s %s %s\n' %(c1,g1,rate))
                        result_file.flush()
                        # On a tie, prefer the smaller C at the same gamma.
                        if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1<best_c1):
                            best_rate = rate
                            best_c1,best_g1=c1,g1
                            best_c = 2.0**c1
                            best_g = 2.0**g1
                        print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                            (worker,c1,g1,rate, best_c, best_g, best_rate))
                    db.append((c,g,done_jobs[(c,g)]))
                redraw(db,[best_c1, best_g1, best_rate])
        redraw(db,[best_c1, best_g1, best_rate],True)
    finally:
        # Always tell the workers to stop; otherwise their threads would
        # stay blocked on the job queue if gathering is interrupted.
        job_queue.put((WorkerStopToken,None))

    print("%s %s %s" % (best_c, best_g, best_rate))
|
359
|
+
# Run the search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
@@ -0,0 +1,146 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
from sys import argv, exit, stdout, stderr
|
3
|
+
from random import randint
|
4
|
+
|
5
|
+
# Selection method; 0 is the default.
method = 0
# n and dataset_filename are module globals as well; they are not given a
# value here.
# Empty output names.
subset_filename = rest_filename = ""
|
10
|
+
|
11
|
+
def exit_with_help():
    """Print the usage text for this script and exit with status 1."""
    usage_lines = [
        "Usage: %s [options] dataset number [output1] [output2]",
        "",
        "This script selects a subset of the given dataset.",
        "",
        "options:",
        "-s method : method of selection (default 0)",
        "0 -- stratified selection (classification only)",
        "1 -- random selection",
        "",
        "output1 : the subset (optional)",
        "output2 : rest of the data (optional)",
        "If output1 is omitted, the subset will be printed on the screen.",
    ]
    usage = "\n".join(usage_lines)
    print(usage % argv[0])
    exit(1)
|
26
|
+
|
27
|
+
def process_options():
    """Parse ``argv`` into the module-level settings.

    Leading arguments that start with "-" are treated as options; only
    ``-s method`` (0 or 1) is acted on.  The first non-option argument is
    the dataset file, the next is the subset size ``n``, optionally followed
    by the subset output file and the rest output file.

    Prints the usage text and exits when fewer than two arguments are given,
    when ``-s`` has no value or an unknown value, or when the dataset or the
    number is missing after the options.
    """
    global method, n
    global dataset_filename, subset_filename, rest_filename

    argc = len(argv)
    if argc < 3:
        exit_with_help()

    i = 1
    while i < argc:
        # startswith() also copes with an empty-string argument, which
        # indexing the first character would not.
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                # "-s" was the last argument: its value is missing.
                exit_with_help()
            method = int(argv[i])
            if method < 0 or method > 1:
                print("Unknown selection method %d" % (method))
                exit_with_help()
        i = i + 1

    # Both positional arguments (dataset, number) must still be present.
    if i + 1 >= argc:
        exit_with_help()

    dataset_filename = argv[i]
    n = int(argv[i+1])
    if i+2 < argc:
        subset_filename = argv[i+2]
    if i+3 < argc:
        rest_filename = argv[i+3]
|
53
|
+
|
54
|
+
def main():
    """Select a subset of the dataset and write it out.

    Reads the label (first field) of every line of ``dataset_filename``,
    marks lines as selected using the stratified (``method == 0``) or random
    (``method == 1``) scheme, then copies selected lines to the subset
    output (a file, or stdout when no name was given) and, if a rest file
    was given, all other lines to it.  A warning is written to stderr when
    the stratified scheme had to force at least one instance for a class.
    """
    class Label:
        # One dataset line: its label, its line index and a selected flag.
        def __init__(self, label, index, selected):
            self.label = label
            self.index = index
            self.selected = selected

    process_options()

    # get labels
    labels = []
    with open(dataset_filename, 'r') as f:
        for i, line in enumerate(f):
            labels.append(Label(float((line.split())[0]), i, 0))
    l = len(labels)

    file1 = stdout
    file2 = None
    try:
        # determine where to output
        if subset_filename != "":
            file1 = open(subset_filename, 'w')
        if rest_filename != "":
            file2 = open(rest_filename, 'w')

        # select the subset
        warning = 0
        if method == 0 and l > 0: # stratified (nothing to do when empty)
            labels.sort(key = lambda x: x.label)

            # Sentinel with a label different from the last real one, so the
            # final class is closed by the loop below as well.
            label_end = labels[l-1].label + 1
            labels.append(Label(label_end, l, 0))

            begin = 0
            label = labels[begin].label
            for i in range(l+1):
                new_label = labels[i].label
                if new_label != label:
                    nr_class = i - begin
                    # this class's share of the n requested instances
                    k = i*n//l - begin*n//l
                    # at least one instance per class
                    if k == 0:
                        k = 1
                        warning = warning + 1
                    # Pick k of the nr_class lines: each line is taken when
                    # a draw over the lines still unvisited falls below the
                    # number of picks still needed.
                    for j in range(nr_class):
                        if randint(0, nr_class-j-1) < k:
                            labels[begin+j].selected = 1
                            k = k - 1
                    begin = i
                    label = new_label
        elif method == 1: # random
            k = n
            for i in range(l):
                if randint(0,l-i-1) < k:
                    labels[i].selected = 1
                    k = k - 1

        # output
        if method == 0:
            # back to file order, so labels[i] matches line i again
            labels.sort(key = lambda x: int(x.index))

        with open(dataset_filename, 'r') as f:
            for i, line in enumerate(f):
                if labels[i].selected == 1:
                    file1.write(line)
                elif file2 is not None:
                    file2.write(line)

        if warning > 0:
            stderr.write("""\
Warning:
1. You may have regression data. Please use -s 1.
2. Classification data unbalanced or too small. We select at least 1 per class.
The subset thus contains %d instances.
""" % (n+warning))
    finally:
        # cleanup: close only the files opened here; stdout is flushed, not
        # closed, so later writes to it keep working.
        if file1 is stdout:
            file1.flush()
        else:
            file1.close()
        if file2 is not None:
            file2.close()
|
145
|
+
|
146
|
+
# Run the selection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|