CooCoo 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +16 -0
- data/CooCoo.gemspec +47 -0
- data/Gemfile +4 -0
- data/Gemfile.lock +88 -0
- data/README.md +123 -0
- data/Rakefile +81 -0
- data/bin/cuda-dev-info +25 -0
- data/bin/cuda-free +28 -0
- data/bin/cuda-free-trend +7 -0
- data/bin/ffi-gen +267 -0
- data/bin/spec_runner_html.sh +42 -0
- data/bin/trainer +198 -0
- data/bin/trend-cost +13 -0
- data/examples/char-rnn.rb +405 -0
- data/examples/cifar/cifar.rb +94 -0
- data/examples/img-similarity.rb +201 -0
- data/examples/math_ops.rb +57 -0
- data/examples/mnist.rb +365 -0
- data/examples/mnist_classifier.rb +293 -0
- data/examples/mnist_dream.rb +214 -0
- data/examples/seeds.rb +268 -0
- data/examples/seeds_dataset.txt +210 -0
- data/examples/t10k-images-idx3-ubyte +0 -0
- data/examples/t10k-labels-idx1-ubyte +0 -0
- data/examples/train-images-idx3-ubyte +0 -0
- data/examples/train-labels-idx1-ubyte +0 -0
- data/ext/buffer/Rakefile +50 -0
- data/ext/buffer/buffer.pre.cu +727 -0
- data/ext/buffer/matrix.pre.cu +49 -0
- data/lib/CooCoo.rb +1 -0
- data/lib/coo-coo.rb +18 -0
- data/lib/coo-coo/activation_functions.rb +344 -0
- data/lib/coo-coo/consts.rb +5 -0
- data/lib/coo-coo/convolution.rb +298 -0
- data/lib/coo-coo/core_ext.rb +75 -0
- data/lib/coo-coo/cost_functions.rb +91 -0
- data/lib/coo-coo/cuda.rb +116 -0
- data/lib/coo-coo/cuda/device_buffer.rb +240 -0
- data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
- data/lib/coo-coo/cuda/error.rb +51 -0
- data/lib/coo-coo/cuda/host_buffer.rb +117 -0
- data/lib/coo-coo/cuda/runtime.rb +157 -0
- data/lib/coo-coo/cuda/vector.rb +315 -0
- data/lib/coo-coo/data_sources.rb +2 -0
- data/lib/coo-coo/data_sources/xournal.rb +25 -0
- data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
- data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
- data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
- data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
- data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
- data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
- data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
- data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
- data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
- data/lib/coo-coo/debug.rb +8 -0
- data/lib/coo-coo/dot.rb +129 -0
- data/lib/coo-coo/drawing.rb +4 -0
- data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
- data/lib/coo-coo/drawing/canvas.rb +68 -0
- data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
- data/lib/coo-coo/drawing/sixel.rb +214 -0
- data/lib/coo-coo/enum.rb +17 -0
- data/lib/coo-coo/from_name.rb +58 -0
- data/lib/coo-coo/fully_connected_layer.rb +205 -0
- data/lib/coo-coo/generation_script.rb +38 -0
- data/lib/coo-coo/grapher.rb +140 -0
- data/lib/coo-coo/image.rb +286 -0
- data/lib/coo-coo/layer.rb +67 -0
- data/lib/coo-coo/layer_factory.rb +26 -0
- data/lib/coo-coo/linear_layer.rb +59 -0
- data/lib/coo-coo/math.rb +607 -0
- data/lib/coo-coo/math/abstract_vector.rb +121 -0
- data/lib/coo-coo/math/functions.rb +39 -0
- data/lib/coo-coo/math/interpolation.rb +7 -0
- data/lib/coo-coo/network.rb +264 -0
- data/lib/coo-coo/neuron.rb +112 -0
- data/lib/coo-coo/neuron_layer.rb +168 -0
- data/lib/coo-coo/option_parser.rb +18 -0
- data/lib/coo-coo/platform.rb +17 -0
- data/lib/coo-coo/progress_bar.rb +11 -0
- data/lib/coo-coo/recurrence/backend.rb +99 -0
- data/lib/coo-coo/recurrence/frontend.rb +101 -0
- data/lib/coo-coo/sequence.rb +187 -0
- data/lib/coo-coo/shell.rb +2 -0
- data/lib/coo-coo/temporal_network.rb +291 -0
- data/lib/coo-coo/trainer.rb +21 -0
- data/lib/coo-coo/trainer/base.rb +67 -0
- data/lib/coo-coo/trainer/batch.rb +82 -0
- data/lib/coo-coo/trainer/batch_stats.rb +27 -0
- data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
- data/lib/coo-coo/trainer/stochastic.rb +47 -0
- data/lib/coo-coo/transformer.rb +272 -0
- data/lib/coo-coo/vector_layer.rb +194 -0
- data/lib/coo-coo/version.rb +3 -0
- data/lib/coo-coo/weight_deltas.rb +23 -0
- data/prototypes/convolution.rb +116 -0
- data/prototypes/linear_drop.rb +51 -0
- data/prototypes/recurrent_layers.rb +79 -0
- data/www/images/screamer.png +0 -0
- data/www/images/screamer.xcf +0 -0
- data/www/index.html +82 -0
- metadata +373 -0
@@ -0,0 +1,42 @@
|
|
1
|
+
#!/bin/zsh
# Builds an HTML index of spec results at $INDEX; this part writes the
# opening markup, later parts of the script append one <li> per spec.

INDEX=doc/spec/index.html

# The index is written before any spec has run, so nothing has created its
# directory yet. Without this the redirect below fails on a clean checkout
# and the index ends up missing its header.
mkdir -p "$(dirname "$INDEX")"

# Start the index.html
cat <<EOT > $INDEX
<html>
<head>
<style>.PASSED { color: green; } .FAILED { color: red; }</style>
</head>
<body>
<ul>
EOT
|
14
|
+
|
15
|
+
# The spec runner
#
# Runs one spec file through rspec with the HTML formatter, storing the
# report at doc/<dir>/<base>.html. Echoes the spec path followed by PASSED
# or FAILED, then appends a matching list item to $INDEX.
#
#   $1 - path of the spec file to run
#
# STATE is intentionally left as a global, as in the original script.
function run_spec()
{
  local spec="$1"
  local DIR=$(dirname $spec)
  local BASE=$(basename -s .spec $spec)
  echo -n "$spec "

  # Failing to create the report directory counts as a failed spec as well.
  if mkdir -p doc/$DIR && bundle exec rspec -Ilib -Iexamples $spec -f html > doc/$DIR/$BASE.html; then
    STATE=PASSED
  else
    STATE=FAILED
  fi

  echo $STATE

  # Heredoc body and terminator stay at column 0: <<- only strips tabs.
  cat <<-EOT >> $INDEX
<li class="$STATE"><a href="../$DIR/$BASE.html">$DIR/$BASE</a> — <span class="state">$STATE</span></li>
EOT
}
|
31
|
+
|
32
|
+
# And the loop: every *.spec file below spec/, in sorted order.
for spec in $(find spec -name '*.spec' | sort); do
  run_spec "$spec"
done

# Finalize the index by closing the tags opened at the top of the script.
cat <<EOT >> $INDEX
</ul>
</body>
</html>
EOT
|
data/bin/trainer
ADDED
@@ -0,0 +1,198 @@
|
|
1
|
+
#!/usr/bin/env ruby
# Command line driver: loads a training set and a network (or generates one
# from a prototype script), optionally trains it, then runs test predictions.

# Project root as an absolute path. Resolving from __dir__ keeps this correct
# when the script is started from inside bin/ (where nesting File.dirname on
# a bare file name would only ever yield ".").
ROOT = File.expand_path('..', __dir__)
$: << File.join(ROOT, 'lib') << File.join(ROOT, 'examples')

require 'pathname' # the --model option wraps its argument in a Pathname
require 'colorize'
require 'coo-coo'
|
8
|
+
|
9
|
+
# Builds a CooCoo::GenerationScript for the file at +path+ (with $stdout as
# its second argument) and returns that script's +opts+. The --help handler
# prints the result.
def load_script_opts(path)
  script = CooCoo::GenerationScript.new(path, $stdout)
  script.opts
end
|
14
|
+
|
15
|
+
require 'pry'

# Shared state for the on-demand debug console: the top-level binding to
# open pry in, the main thread, and (once started) the console thread.
$pry = { binding: binding, main_thread: Thread.current }

# Sending USR1 opens a pry session on the top-level binding in its own
# thread. The thread is created at most once; later signals reuse the
# stored one.
Signal.trap("USR1") do
  $pry[:thread] ||= Thread.new { $pry[:binding].pry }
end
|
26
|
+
|
27
|
+
require 'ostruct'

# Settings collected from the command line, seeded with their defaults.
options = OpenStruct.new(
  epochs: 1000,
  dataset: nil,
  model_path: nil,
  prototype: nil,
  trainer: nil,
  num_tests: 20,
  start_tests_at: 0,
  test_cost: CooCoo::CostFunctions::MeanSquare
)

opts = CooCoo::OptionParser.new do |o|
  o.on('-d', '--dataset PATH') { |path| options.dataset = path }

  # A model file ending in ".bin" switches on the binary (Marshal) format
  # unless --binary already did.
  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    options.binary_blob ||= File.extname(options.model_path) == '.bin'
  end

  o.on('--prototype PATH') { |path| options.prototype = path }

  o.on('--binary') { options.binary_blob = true }

  o.on('-e', '--train NUMBER') { |count| options.epochs = count.to_i }

  o.on('-p', '--predict NUMBER') { |count| options.num_tests = count.to_i }

  o.on('-s', '--skip NUMBER') { |count| options.start_tests_at = count.to_i }

  o.on('-t', '--trainer NAME') { |name| options.trainer = name }

  o.on('--test-cost NAME') do |name|
    options.test_cost = CooCoo::CostFunctions.from_name(name)
  end

  o.on('--verbose') { options.verbose = true }

  # Prints this parser's help followed by the help of every component that
  # was named before -h on the command line, then exits.
  o.on('-h', '--help') do
    puts(opts)

    [ options.dataset, options.prototype ].each do |script_path|
      puts(load_script_opts(script_path.to_s)) if script_path
    end

    if options.trainer
      t = CooCoo::Trainer.from_name(options.trainer)
      raise ArgumentError.new("Unknown trainer #{options.trainer}") unless t
      trainer_help, _ = t.options
      puts(trainer_help)
    end

    exit
  end
end

# Whatever this parser does not consume is handed on to the dataset,
# prototype and trainer parsers further down.
argv = opts.parse!(ARGV)
|
101
|
+
|
102
|
+
puts("Loading training set #{options.dataset}")

# The dataset script consumes its own arguments and hands back the rest.
training_set_gen = CooCoo::GenerationScript.new(options.dataset.to_s, $stdout)
argv, training_set = training_set_gen.call(argv)

# Reuse a saved model when one exists, otherwise build a fresh network from
# the prototype script. File.exist? is used because the File.exists? alias
# was removed in Ruby 3.2.
if options.model_path && File.exist?(options.model_path.to_s)
  puts("Loading network #{options.model_path}")
  net = if options.binary_blob
          # Binary read: the blob is Marshal data, not text.
          # NOTE(review): Marshal.load can instantiate arbitrary objects;
          # only load model files from a trusted source.
          Marshal.load(File.binread(options.model_path))
        else
          CooCoo::Network.load(options.model_path)
        end
else
  puts("Generating network from #{options.prototype}")
  net_gen = CooCoo::GenerationScript.new(options.prototype.to_s, $stdout)
  argv, net = net_gen.call(argv, training_set)
end
|
119
|
+
|
120
|
+
trainer = nil
trainer_options = nil

# Resolve the trainer named with -t and let it parse its own options out of
# the remaining arguments.
if options.trainer
  trainer = CooCoo::Trainer.from_name(options.trainer)
  # Same check the --help handler performs: report an unknown name clearly
  # instead of failing on nil below.
  raise ArgumentError.new("Unknown trainer #{options.trainer}") unless trainer
  t_opts, trainer_options = trainer.options
  argv = t_opts.parse!(argv)
end

# Every parser has had its turn; anything left over is a mistake.
unless argv.empty?
  raise ArgumentError.new("Unknown arguments: #{argv.inspect}")
end
|
132
|
+
|
133
|
+
# Summary of the network about to be used, followed by one line per layer
# (index, input count, size, class).
puts("Net ready:",
     "\tAge: #{net.age}",
     "\tActivation: #{net.activation_function}",
     "\tInputs: #{net.num_inputs}",
     "\tOutputs: #{net.num_outputs}",
     "\tLayers: #{net.num_layers}")

net.layers.each_with_index do |layer, index|
  puts("\t\t#{index}\t#{layer.num_inputs}\t#{layer.size}\t#{layer.class}")
end

$stdout.flush
|
144
|
+
|
145
|
+
# Training only happens when a trainer was selected with -t.
if options.trainer
  num_batches = options.epochs.to_i * training_set.size / trainer_options.batch_size

  puts("Training in #{num_batches} batches of #{trainer_options.batch_size} examples for #{options.epochs} epochs at a rate of #{trainer_options.learning_rate} with #{trainer.name} using #{trainer_options.cost_function.name}.")

  bar = CooCoo::ProgressBar.create(:total => num_batches)

  # The trainer's own parsed options are merged over the two fixed arguments.
  training_args = { network: net,
                    data: training_set.each.cycle(options.epochs)
                  }.merge(trainer_options.to_h)

  # Per reported batch: abort on NaN, save the model if a path was given,
  # then log the status and advance the progress bar.
  trainer.train(training_args) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?

    status = [
      "Cost\t#{avg_err.average}\t#{options.verbose ? avg_err : nil}",
      "Batch #{stats.batch}/#{num_batches} took #{stats.total_time} seconds"
    ]

    if options.model_path
      if options.binary_blob
        File.write_to(options.model_path) { |f| f.write(Marshal.dump(net)) }
      else
        net.save(options.model_path)
      end
      status << "Saved network to #{options.model_path}"
    end

    bar.log(status.join("\n"))
    bar.increment
  end
end
|
174
|
+
|
175
|
+
CHECKMARK = "\u2714".green
CROSSMARK = "\u2718".red

# Runs up to num_tests examples (after skipping start_tests_at) through the
# network. An example passes when the index of the largest output equals the
# index of the largest target value.
if options.num_tests
  puts("Running #{options.num_tests} tests starting from #{options.start_tests_at}:")
  examples = training_set.each
  # Skip lazily so the whole training set is not materialized just to drop
  # the first few examples.
  examples = examples.lazy.drop(options.start_tests_at) if options.start_tests_at > 0
  passes = 0
  tested = 0

  examples.first(options.num_tests).each_with_index do |(target, pixels), i|
    out, _hidden_state = net.predict(pixels)
    loss = options.test_cost.call(target, out)
    target_max = target.each.with_index.max
    out_max = out.each.with_index.max
    passed = target_max[1] == out_max[1]
    passes += 1 if passed
    tested += 1
    puts("#{passed ? CHECKMARK : CROSSMARK} #{options.start_tests_at + i}\t#{loss.average}\t#{target_max}\t#{out[target_max[1]]}\t#{out_max}\t#{target}\t#{out}")
    $stdout.flush
    GC.start
  end

  puts
  # Report against the number of examples actually run: the set may hold
  # fewer than requested, and dividing by a requested 0 would print NaN.
  percent = tested > 0 ? passes / tested.to_f * 100.0 : 0.0
  puts("#{passes}/#{tested} #{percent}% passed")
end
|
data/bin/trend-cost
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#!/bin/zsh
|
2
|
+
|
3
|
+
# Runs the given command with line-buffered stdout (stdbuf -oL) so each
# stage of the pipeline below passes lines on as soon as they are written.
# "$@" forwards every argument exactly as given, including empty ones and
# ones containing spaces.
function run()
{
  stdbuf -oL "$@"
}
|
7
|
+
|
8
|
+
# Copy stdin to stdout unchanged while a side pipeline picks out the lines
# mentioning "cost", turns spaces into tabs, keeps the second field and
# feeds it to trend. Title, geometry, size and scale come from the TREND_*
# environment variables, with defaults for the first three.
tee >(
  run grep -i cost |
    run sed -e 's: :\t:g' |
    run cut -f 2 |
    trend -t "${TREND_TITLE:-Cost}" -c 1a -geometry ${TREND_GEOMETRY:-320x64-0-24} - ${TREND_SIZE:-$((60*4*5))x2} ${TREND_SCALE})

# Kill whatever background jobs are still listed once the input ends.
if [[ -n $(jobs -p) ]]; then
  kill $(jobs -p)
fi
|
@@ -0,0 +1,405 @@
|
|
1
|
+
require 'coo-coo'
|
2
|
+
|
3
|
+
# Abstract base for the byte <-> vector codecs used by this example.
# Subclasses provide #vector_size, #encode_input and #decode_output (plus
# #decode_byte where it applies); the string helpers below build on them.
class InputEncoder
  # Ruby always makes #initialize private, so no visibility keyword is needed.
  def initialize
  end

  # Length of the vectors produced by #encode_input.
  def vector_size
    raise NotImplementedError
  end

  # Encodes a single input value as a vector.
  def encode_input(s)
    raise NotImplementedError
  end

  # Maps an encoder-specific index back to a byte value.
  def decode_byte(b)
    raise NotImplementedError
  end

  # Decodes one output vector into a byte value.
  def decode_output(b)
    raise NotImplementedError
  end

  # Encodes each byte of +s+ with #encode_input and returns the results.
  def encode_string(s)
    s.bytes.map { |byte| encode_input(byte) }
  end

  # Packs an array of byte values into a string.
  def decode_sequence(s)
    s.pack('c*')
  end

  # Decodes each element of +output+ with #decode_output and joins the
  # resulting bytes into a string.
  def decode_to_string(output)
    bytes = output.collect { |vector| decode_output(vector) }
    decode_sequence(bytes)
  end
end
|
37
|
+
|
38
|
+
# Compact, case-insensitive encoder. Bytes map to these indexes:
#   0       anything not listed below
#   1       space
#   2..27   letters (upper and lower case share an index)
#   28..37  digits
# Vectors have 39 entries; the mapping above never produces index 38.
class LittleInputEncoder < InputEncoder
  UA = 'A'.ord
  UZ = 'Z'.ord
  LA = 'a'.ord
  LZ = 'z'.ord
  N0 = '0'.ord
  N9 = '9'.ord
  SPACE = ' '.ord

  def vector_size
    39
  end

  # Maps a byte value to its index in the table above.
  def encode_byte(b)
    return (b - UA) + 2 if b.between?(UA, UZ)
    return (b - LA) + 2 if b.between?(LA, LZ)
    return (b - N0) + 26 + 2 if b.between?(N0, N9)

    b == SPACE ? 1 : 0
  end

  # Inverse of #encode_byte. Letters come back lower case; index 0 and any
  # index above 37 decode to a space.
  def decode_byte(i)
    return SPACE if i <= 1 || i > 37

    i <= 27 ? LA + (i - 2) : N0 + (i - 28)
  end

  # One-hot vector with the entry for byte +b+ set to 1.0.
  def encode_input(b)
    one_hot = CooCoo::Vector.zeros(vector_size)
    one_hot[encode_byte(b)] = 1.0
    one_hot
  end

  # Byte for the largest entry of +v+.
  def decode_output(v)
    _value, index = v.each_with_index.max
    decode_byte(index)
  end
end
|
88
|
+
|
89
|
+
# One-hot encoder over all 256 byte values: index == byte value.
class AsciiInputEncoder < InputEncoder
  def vector_size
    256
  end

  # One-hot vector for byte +b+. Vectors are built once per byte and cached
  # in $encoded_input_hash, so callers receive a shared object and must not
  # modify it.
  def encode_input(b)
    $encoded_input_hash ||= Hash.new do |cache, byte|
      one_hot = CooCoo::Vector.zeros(vector_size)
      one_hot[byte] = 1.0
      cache[byte] = one_hot
    end

    $encoded_input_hash[b]
  end

  # Index of the largest entry of +v+. Computed directly rather than
  # memoized per vector: a cache keyed on output vectors keeps every vector
  # it has seen alive for as long as the process runs.
  def decode_output(v)
    v.each_with_index.max[1]
  end
end
|
113
|
+
|
114
|
+
# Returns an Enumerator with one training example per position in +data+.
#
# Each example is [ target, input ]: +input+ is the +sequence_size+ bytes
# starting at that position and +target+ is the same window shifted one byte
# ahead, both encoded with +encoder+ and wrapped in CooCoo::Sequence.
#
# Windows that run past the end of +data+ are padded with 0, so input and
# target always hold exactly +sequence_size+ elements. The earlier slicing
# form produced shorter and mismatched windows near the end, which left its
# `e || 0` padding without effect.
def training_enumerator(data, sequence_size, encoder)
  # Encoded window of +sequence_size+ bytes beginning at +start+.
  window = lambda do |start|
    Array.new(sequence_size) { |offset| encoder.encode_input(data[start + offset] || 0) }
  end

  Enumerator.new do |yielder|
    data.size.times do |i|
      input = window.call(i)
      output = window.call(i + 1)
      yielder << [ CooCoo::Sequence[output], CooCoo::Sequence[input] ]
    end
  end
end
|
130
|
+
|
131
|
+
if __FILE__ == $0
|
132
|
+
# Picks, at random, the index of one of the +range+ largest values in +arr+.
#
# +range+ is clamped to 1..arr.size, so asking for more candidates than
# there are values, or for none, still returns a valid index. Returns 0 for
# an empty +arr+.
def sample_top(arr, range)
  ranked = arr.each.with_index.sort
  return 0 if ranked.empty?

  count = range.abs.clamp(1, ranked.size)
  # Ascending sort: the best candidates are at the end.
  top = ranked.last(count).collect(&:last)
  top[rand(top.size)]
end
|
136
|
+
|
137
|
+
# Randomly picks an index of +arr+: draws a random vector of the same size,
# keeps the positions where the normalized value is at least the random
# draw, sorts those [difference, index] pairs and chooses one of them.
# Returns 0 when nothing qualifies or the chosen slot is empty.
#
# NOTE(review): the random position is multiplied by +temperature+, so a
# temperature above 1 can point past the end of the candidates (giving 0).
# Confirm this is intended.
def sample(arr, temperature = 1.0)
  normalized = arr.normalize
  noise = CooCoo::Vector.rand(arr.size)
  candidates = (noise - normalized).each.with_index.select { |diff, _index| diff <= 0.0 }.sort
  chosen = candidates[rand(candidates.size) * temperature]
  return 0 unless chosen

  chosen[1] || 0
end
|
143
|
+
|
144
|
+
require 'ostruct'

# Settings for this run, seeded with their defaults; the option parser
# below overrides them.
options = OpenStruct.new(
  encoder: AsciiInputEncoder.new,
  recurrent_size: 1024,
  activation_function: CooCoo::ActivationFunctions.from_name('Logistic'),
  epochs: 1000,
  model_path: "char-rnn.coo-coo_model",
  input_path: nil,
  backprop_limit: nil,
  trainer: nil,
  sequence_size: 4,
  num_layers: 1,
  hidden_size: nil,
  num_recurrent_layers: 2,
  softmax: nil,
  verbose: false,
  generator: false,
  generator_temperature: 4,
  generator_amount: 140,
  generator_init: "\n",
  sampler: method(:sample)
)
|
166
|
+
|
167
|
+
opts = CooCoo::OptionParser.new do |o|
  # Converts +text+ with to_i and raises ArgumentError with +message+
  # unless the result is greater than zero.
  positive_int = lambda do |text, message|
    value = text.to_i
    raise ArgumentError.new(message) if value <= 0
    value
  end

  o.on('-v', '--verbose') { options.verbose = true }

  o.on('--little') { options.encoder = LittleInputEncoder.new }

  o.on('-m', '--model PATH') { |path| options.model_path = path }

  o.on('-b', '--binary') { options.binary = true }

  o.on('--recurrent-size NUMBER') { |size| options.recurrent_size = size.to_i }

  o.on('--activation NAME') do |name|
    options.activation_function = CooCoo::ActivationFunctions.from_name(name)
  end

  o.on('--epochs NUMBER') { |count| options.epochs = count.to_i }

  o.on('--backprop-limit NUMBER') { |count| options.backprop_limit = count.to_i }

  o.on('--hidden-size NUMBER') { |size| options.hidden_size = size.to_i }

  o.on('--recurrent-layers NUMBER') { |count| options.num_recurrent_layers = count.to_i }

  # Softmax output also selects the cross entropy cost function.
  o.on('--softmax') do
    options.softmax = true
    options.cost_function = CooCoo::CostFunctions::CrossEntropy
  end

  o.on('-p', '--predict') { options.trainer = nil }

  o.on('-t', '--trainer NAME') do |name|
    options.trainer = CooCoo::Trainer.from_name(name)
  end

  o.on('-n', '--sequence-size NUMBER') do |text|
    options.sequence_size = positive_int.call(text, "sequence-size must be > 0")
  end

  o.on('--layers NUMBER') do |text|
    options.num_layers = positive_int.call(text, "number of layers must be > 0")
  end

  # Asking for generated text also sets how many characters to produce.
  o.on('-g', '--generate AMOUNT') do |amount|
    options.generator = true
    options.generator_amount = amount.to_i
  end

  o.on('--generator-init STRING') { |text| options.generator_init = text }

  o.on('--generator-temp NUMBER') { |value| options.generator_temperature = value.to_i }

  o.on('--generator-sample-top') { options.sampler = method(:sample_top) }

  o.on('--seed NUMBER') { |value| srand(value.to_i) }

  # Prints this parser's help, plus the trainer's when one was already
  # selected on the command line, then exits.
  o.on('-h', '--help') do
    puts(o)
    if options.trainer
      trainer_help, _ = options.trainer.options
      puts(trainer_help)
    end

    exit
  end
end

argv = opts.parse!(ARGV)

# A selected trainer parses its own options out of what is left.
if options.trainer
  t_opts, trainer_options = options.trainer.options
  argv = t_opts.parse!(argv)
end

# The first remaining argument, if any, is the input file.
options.input_path = argv.first
encoder = options.encoder
# Without --hidden-size the hidden layers match the encoder's vector size.
options.hidden_size ||= encoder.vector_size
|
275
|
+
|
276
|
+
# Load the saved model when the file exists, otherwise assemble a new
# network from the options. File.exist? is used because the File.exists?
# alias was removed in Ruby 3.2.
if File.exist?(options.model_path)
  $stdout.print("Loading #{options.model_path}...")
  $stdout.flush
  net = if options.binary
          # Binary read: the blob is Marshal data, not text.
          # NOTE(review): Marshal.load can instantiate arbitrary objects;
          # only load model files from a trusted source.
          Marshal.load(File.binread(options.model_path))
        else
          CooCoo::TemporalNetwork.new(network: CooCoo::Network.load(options.model_path))
        end
  puts("\rLoaded #{options.model_path}:")
else
  puts("Creating new network")
  puts("\tNumber of inputs: #{encoder.vector_size}")
  puts("\tNumber of layers: #{options.num_layers}")
  puts("\tHidden size: #{options.hidden_size}#{' with mix' if options.hidden_size != encoder.vector_size}")
  puts("\tRecurrent size: #{options.recurrent_size}")
  puts("\tActivation: #{options.activation_function}")
  puts("\tRecurrent layers: #{options.num_recurrent_layers}")

  net = CooCoo::TemporalNetwork.new()

  # When the hidden size differs from the encoder's vector size, a layer is
  # added at each end to convert between the two widths.
  if options.hidden_size != encoder.vector_size
    net.layer(CooCoo::Layer.new(encoder.vector_size, options.hidden_size, options.activation_function))
  end

  # Each recurrent stage is: frontend, num_layers layers sized
  # hidden_size + recurrent_size, then the frontend's matching backend.
  options.num_recurrent_layers.to_i.times do
    rec = CooCoo::Recurrence::Frontend.new(options.hidden_size, options.recurrent_size)
    net.layer(rec)

    width = options.hidden_size + rec.recurrent_size
    options.num_layers.times do
      net.layer(CooCoo::Layer.new(width, width, options.activation_function))
    end

    net.layer(rec.backend)
    #net.layer(CooCoo::Layer.new(options.hidden_size, options.hidden_size, options.activation_function))
  end

  if options.hidden_size != encoder.vector_size
    net.layer(CooCoo::Layer.new(options.hidden_size, encoder.vector_size, options.activation_function))
  end

  if options.softmax
    net.layer(CooCoo::LinearLayer.new(encoder.vector_size, CooCoo::ActivationFunctions.from_name('ShiftedSoftMax')))
  end
end

net.backprop_limit = options.backprop_limit
|
320
|
+
|
321
|
+
# Summary of the network in use.
puts("\tAge: #{net.age}",
     "\tInputs: #{net.num_inputs}",
     "\tOutputs: #{net.num_outputs}",
     "\tLayers: #{net.num_layers}")

# Input text comes from the file named on the command line, or from stdin.
text = if options.input_path
         puts("Reading #{options.input_path}")
         File.read(options.input_path)
       else
         puts("Reading stdin...")
         $stdin.read
       end
data = text.bytes
puts("Read #{data.size} bytes")

puts("Splitting into #{options.sequence_size} byte sequences.")
training_data = training_enumerator(data, options.sequence_size, encoder)
|
338
|
+
|
339
|
+
# Training only happens when a trainer was selected with -t.
if options.trainer
  puts("Training on #{data.size} bytes from #{options.input_path || "stdin"} in #{options.epochs} epochs in batches of #{trainer_options.batch_size} at a learning rate of #{trainer_options.learning_rate}...")

  trainer = options.trainer
  bar = CooCoo::ProgressBar.create(:total => (options.epochs * data.size / trainer_options.batch_size.to_f).ceil)

  # Per reported batch: abort on NaN, save the model, log the status and
  # advance the progress bar.
  trainer.train({ network: net,
                  data: training_data.cycle(options.epochs),
                }.merge(trainer_options.to_h)) do |stats|
    # NOTE(review): #average is applied here and again when formatting the
    # status line; confirm the first call does not already give a scalar.
    cost = stats.average_loss.average
    raise 'Cost went to NAN' if cost.nan?
    status = [ "Cost #{cost.average} #{options.verbose ? cost : ''}" ]

    File.write_to(options.model_path) do |f|
      if options.binary
        # write, not puts: puts may append a newline to the Marshal blob.
        # bin/trainer saves its binary models the same way.
        f.write(Marshal.dump(net))
      else
        f.puts(net.to_hash.to_yaml)
      end
    end
    status << "Saved to #{options.model_path}"
    bar.log(status.join("\n"))
    bar.increment
  end
end
|
363
|
+
|
364
|
+
if options.generator
  # Prime the network with the init string and then the whole input, keeping
  # the hidden state and the most recent output.
  output, hidden_state = net.predict(encoder.encode_string(options.generator_init), {})
  data.each do |byte|
    output, hidden_state = net.predict(encoder.encode_input(byte), hidden_state)
  end

  # Sample a character from the latest output, print it, and feed it back.
  options.generator_amount.to_i.times do
    char = options.sampler.call(output, options.generator_temperature)
    # Encoders with fewer than 256 slots yield an index that still has to be
    # mapped back to a byte.
    char = encoder.decode_byte(char) unless encoder.vector_size == 256
    $stdout.write(char.chr)
    $stdout.flush
    output, hidden_state = net.predict(encoder.encode_input(char), hidden_state)
  end

  $stdout.puts
else
  puts("Predicting:")

  # Pass 1: predict on the window starting at each input position, carrying
  # the hidden state from one window to the next.
  hidden_state = nil
  predictions = data.each_index.map do |i|
    window = data[i, options.sequence_size].map { |e| encoder.encode_input(e || 0) }
    predicted, hidden_state = net.predict(window, hidden_state)
    predicted.collect { |vector| encoder.decode_output(vector) }
  end

  predictions.each_with_index do |decoded, i|
    input = data[i, options.sequence_size]
    puts("#{i} #{input.inspect} -> #{decoded.inspect}\t#{encoder.decode_sequence(input)} -> #{encoder.decode_sequence(decoded)}")
  end

  puts(encoder.decode_sequence(predictions.map(&:first)))
  puts(encoder.decode_sequence(predictions.map(&:last)))

  # Pass 2: start from a random input byte and feed each decoded output back
  # in as the next input, for as many steps as there are input bytes.
  hidden_state = nil
  current = data[rand(data.size)]
  generated = data.each_index.map do
    predicted, hidden_state = net.predict(encoder.encode_input(current), hidden_state)
    current = encoder.decode_output(predicted)
  end

  puts(generated.inspect)
  puts(encoder.decode_sequence(generated))
end
|
405
|
+
end
|