CooCoo 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +16 -0
- data/CooCoo.gemspec +47 -0
- data/Gemfile +4 -0
- data/Gemfile.lock +88 -0
- data/README.md +123 -0
- data/Rakefile +81 -0
- data/bin/cuda-dev-info +25 -0
- data/bin/cuda-free +28 -0
- data/bin/cuda-free-trend +7 -0
- data/bin/ffi-gen +267 -0
- data/bin/spec_runner_html.sh +42 -0
- data/bin/trainer +198 -0
- data/bin/trend-cost +13 -0
- data/examples/char-rnn.rb +405 -0
- data/examples/cifar/cifar.rb +94 -0
- data/examples/img-similarity.rb +201 -0
- data/examples/math_ops.rb +57 -0
- data/examples/mnist.rb +365 -0
- data/examples/mnist_classifier.rb +293 -0
- data/examples/mnist_dream.rb +214 -0
- data/examples/seeds.rb +268 -0
- data/examples/seeds_dataset.txt +210 -0
- data/examples/t10k-images-idx3-ubyte +0 -0
- data/examples/t10k-labels-idx1-ubyte +0 -0
- data/examples/train-images-idx3-ubyte +0 -0
- data/examples/train-labels-idx1-ubyte +0 -0
- data/ext/buffer/Rakefile +50 -0
- data/ext/buffer/buffer.pre.cu +727 -0
- data/ext/buffer/matrix.pre.cu +49 -0
- data/lib/CooCoo.rb +1 -0
- data/lib/coo-coo.rb +18 -0
- data/lib/coo-coo/activation_functions.rb +344 -0
- data/lib/coo-coo/consts.rb +5 -0
- data/lib/coo-coo/convolution.rb +298 -0
- data/lib/coo-coo/core_ext.rb +75 -0
- data/lib/coo-coo/cost_functions.rb +91 -0
- data/lib/coo-coo/cuda.rb +116 -0
- data/lib/coo-coo/cuda/device_buffer.rb +240 -0
- data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
- data/lib/coo-coo/cuda/error.rb +51 -0
- data/lib/coo-coo/cuda/host_buffer.rb +117 -0
- data/lib/coo-coo/cuda/runtime.rb +157 -0
- data/lib/coo-coo/cuda/vector.rb +315 -0
- data/lib/coo-coo/data_sources.rb +2 -0
- data/lib/coo-coo/data_sources/xournal.rb +25 -0
- data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
- data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
- data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
- data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
- data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
- data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
- data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
- data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
- data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
- data/lib/coo-coo/debug.rb +8 -0
- data/lib/coo-coo/dot.rb +129 -0
- data/lib/coo-coo/drawing.rb +4 -0
- data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
- data/lib/coo-coo/drawing/canvas.rb +68 -0
- data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
- data/lib/coo-coo/drawing/sixel.rb +214 -0
- data/lib/coo-coo/enum.rb +17 -0
- data/lib/coo-coo/from_name.rb +58 -0
- data/lib/coo-coo/fully_connected_layer.rb +205 -0
- data/lib/coo-coo/generation_script.rb +38 -0
- data/lib/coo-coo/grapher.rb +140 -0
- data/lib/coo-coo/image.rb +286 -0
- data/lib/coo-coo/layer.rb +67 -0
- data/lib/coo-coo/layer_factory.rb +26 -0
- data/lib/coo-coo/linear_layer.rb +59 -0
- data/lib/coo-coo/math.rb +607 -0
- data/lib/coo-coo/math/abstract_vector.rb +121 -0
- data/lib/coo-coo/math/functions.rb +39 -0
- data/lib/coo-coo/math/interpolation.rb +7 -0
- data/lib/coo-coo/network.rb +264 -0
- data/lib/coo-coo/neuron.rb +112 -0
- data/lib/coo-coo/neuron_layer.rb +168 -0
- data/lib/coo-coo/option_parser.rb +18 -0
- data/lib/coo-coo/platform.rb +17 -0
- data/lib/coo-coo/progress_bar.rb +11 -0
- data/lib/coo-coo/recurrence/backend.rb +99 -0
- data/lib/coo-coo/recurrence/frontend.rb +101 -0
- data/lib/coo-coo/sequence.rb +187 -0
- data/lib/coo-coo/shell.rb +2 -0
- data/lib/coo-coo/temporal_network.rb +291 -0
- data/lib/coo-coo/trainer.rb +21 -0
- data/lib/coo-coo/trainer/base.rb +67 -0
- data/lib/coo-coo/trainer/batch.rb +82 -0
- data/lib/coo-coo/trainer/batch_stats.rb +27 -0
- data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
- data/lib/coo-coo/trainer/stochastic.rb +47 -0
- data/lib/coo-coo/transformer.rb +272 -0
- data/lib/coo-coo/vector_layer.rb +194 -0
- data/lib/coo-coo/version.rb +3 -0
- data/lib/coo-coo/weight_deltas.rb +23 -0
- data/prototypes/convolution.rb +116 -0
- data/prototypes/linear_drop.rb +51 -0
- data/prototypes/recurrent_layers.rb +79 -0
- data/www/images/screamer.png +0 -0
- data/www/images/screamer.xcf +0 -0
- data/www/index.html +82 -0
- metadata +373 -0
@@ -0,0 +1,293 @@
|
|
1
|
+
#!/bin/env ruby
|
2
|
+
|
3
|
+
require 'fileutils'
|
4
|
+
require 'mnist'
|
5
|
+
require 'ostruct'
|
6
|
+
require 'coo-coo'
|
7
|
+
require 'coo-coo/image'
|
8
|
+
require 'coo-coo/convolution'
|
9
|
+
require 'coo-coo/neuron_layer'
|
10
|
+
require 'coo-coo/subnet'
|
11
|
+
require 'coo-coo/drawing/sixel'
|
12
|
+
require 'colorize'
|
13
|
+
|
14
|
+
# Copies the file at +path+ to a sibling whose name is +path+ followed by "~",
# replacing any backup left by an earlier run. Does nothing when +path+ does
# not exist.
#
# @param path [String, Pathname] file to back up
def backup(path)
  # File.exist? is used because File.exists? was deprecated and later removed
  # from Ruby, which made this helper raise NoMethodError on current versions.
  return unless File.exist?(path)

  backup = "#{path}~"
  File.delete(backup) if File.exist?(backup)
  FileUtils.copy(path, backup)
end
|
23
|
+
|
24
|
+
# Default settings for the classifier script. The command-line switches
# declared below overwrite individual entries.
options = OpenStruct.new(
  examples: 0, # 0 is reported as "all" and disables the example limit when training
  epochs: 1,
  num_tests: 10,
  start_tests_at: 0,
  rotations: 8,
  max_rotation: 90.0, # degrees; converted to radians after parsing
  num_translations: 1,
  translate_dx: 0,
  translate_dy: 0,
  hidden_layers: nil, # nil selects the two-layer default network
  hidden_size: 128,
  activation_function: CooCoo.default_activation,
  trainer: 'Stochastic',
  softmax: false,
  convolution: nil,
  conv_step: 8,
  stacked_convolution: false,
  test_images_path: MNist::TEST_IMAGES_PATH,
  test_labels_path: MNist::TEST_LABELS_PATH
)
|
44
|
+
|
45
|
+
# Command-line switches for the classifier. Every handler stores its value in
# +options+; switches that are not declared here are left for the selected
# trainer's own parser.
opts = CooCoo::OptionParser.new do |o|
  # Prints this parser's help followed by the help of the selected trainer.
  o.on('-h', '--help') do
    puts(o)
    if options.trainer
      t = CooCoo::Trainer.from_name(options.trainer)
      raise NameError, "Unknown trainer #{options.trainer}" unless t
      trainer_parser, = t.options
      puts(trainer_parser)
    end
    exit
  end

  o.on('--sixel') { options.sixel = true }

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    # Only ever switch the flag on here: assigning the comparison result
    # directly would undo a --binary that appeared earlier on the command line
    # whenever the path does not end in ".bin".
    options.binary_blob ||= File.extname(options.model_path) == '.bin'
  end

  o.on('--binary') { options.binary_blob = true }

  o.on('-t', '--train NUMBER', 'train for number of epochs') do |n|
    options.train = true
    options.epochs = n.to_i
  end

  o.on('-e', '--examples NUMBER') { |n| options.examples = n.to_i }
  o.on('-p', '--predict NUMBER') { |n| options.num_tests = n.to_i }
  o.on('-s', '--skip NUMBER') { |n| options.start_tests_at = n.to_i }
  o.on('-r', '--rotations NUMBER') { |n| options.rotations = n.to_i }
  o.on('-a', '--angle NUMBER') { |n| options.max_rotation = n.to_f }

  o.on('--num-translations NUMBER') { |n| options.num_translations = n.to_i }
  o.on('--delta-x NUMBER') { |dx| options.translate_dx = dx.to_f }
  o.on('--delta-y NUMBER') { |dy| options.translate_dy = dy.to_f }

  o.on('-l', '--hidden-layers NUMBER') { |n| options.hidden_layers = n.to_i }
  o.on('--hidden-size NUMBER') { |n| options.hidden_size = n.to_i }

  o.on('-f', '--activation-func FUNC') do |func|
    options.activation_function = CooCoo::ActivationFunctions.from_name(func)
  end

  o.on('--trainer NAME') { |name| options.trainer = name }
  o.on('--softmax') { options.softmax = true }
  o.on('--convolution') { options.convolution = true }

  o.on('--convolution-step NUMBER') do |n|
    n = n.to_i
    raise ArgumentError, "The convolution step must be >0." if n <= 0
    options.conv_step = n
  end
end
|
137
|
+
|
138
|
+
# Consume the switches declared above; whatever remains is handed to the
# selected trainer's own parser.
argv = opts.parse!(ARGV)
# Largest rotation, converted from degrees to radians.
max_rad = options.max_rotation.to_f * Math::PI / 180.0

trainer = nil
trainer_options = nil
if (trainer_name = options.trainer)
  trainer = CooCoo::Trainer.from_name(trainer_name)
  raise NameError, "Unknown trainer #{trainer_name}" unless trainer
  trainer_parser, trainer_options = trainer.options
  argv = trainer_parser.parse!(argv)
end

# NOTE(review): the message talks about stacked convolutions, yet the check
# does not look at options.stacked_convolution, so steps 1..7 are always
# rejected even though --convolution-step accepts them. Confirm the intent.
raise ArgumentError, "The convolution step must be >=8 when stacking convolutions." if options.conv_step < 8
|
152
|
+
|
153
|
+
puts("Loading MNist data")
data = MNist::DataStream.new

net = CooCoo::Network.new

# Either restore a previously saved network or assemble a fresh one from the
# command-line settings.
# File.exist? replaces File.exists?, which no longer exists in current Ruby.
if options.model_path && File.exist?(options.model_path)
  puts("Loading #{options.model_path}")
  if options.binary_blob
    # The blob is written with mode 'wb' when saving, so it is read back in
    # binary mode too.
    # NOTE(review): Marshal.load can instantiate arbitrary objects; only load
    # model files from a trusted source.
    net = Marshal.load(File.binread(options.model_path))
  else
    net.load!(options.model_path)
  end
else
  area = data.width * data.height

  if options.convolution
    l = CooCoo::Convolution::BoxLayer.new(data.width, data.height, options.conv_step, options.conv_step, CooCoo::Layer.new(16, 4, options.activation_function), 4, 4, 2, 2)
    net.layer(l)
    # The layers that follow take the convolution's output as their input.
    area = l.size
  end

  # net.layer(CooCoo::Layer.new(area, 50, options.activation_function))
  # net.layer(CooCoo::Layer.new(50, 20, , options.activation_function))
  # net.layer(CooCoo::Layer.new(20, 10, options.activation_function))

  #net.layer(CooCoo::Layer.new(area, 10, options.activation_function))

  if options.hidden_layers
    # One input-facing layer, hidden_layers - 2 square layers in between
    # (Integer#times does nothing for counts <= 0) and a 10-output layer.
    net.layer(CooCoo::Layer.new(area, options.hidden_size, options.activation_function))
    (options.hidden_layers - 2).times do
      net.layer(CooCoo::Layer.new(options.hidden_size, options.hidden_size, options.activation_function))
    end
    net.layer(CooCoo::Layer.new(options.hidden_size, 10, options.activation_function))
  else
    net.layer(CooCoo::Layer.new(area, area / 4, options.activation_function))
    net.layer(CooCoo::Layer.new(area / 4, 10, options.activation_function))
  end

  #net.layer(CooCoo::Convolution::BoxLayer.new(7, 7, CooCoo::Layer.new(16, 4), 4, 4, 2, 2))
  #net.layer(CooCoo::Layer.new(14 * 14, 10))

  if options.softmax
    net.layer(CooCoo::LinearLayer.new(10, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
  end
end
|
200
|
+
|
201
|
+
# Describe the network that will be trained and/or tested: a few headline
# attributes followed by one line per layer.
puts("Net ready:")
[
  ["Age", :age],
  ["Activation", :activation_function],
  ["Inputs", :num_inputs],
  ["Outputs", :num_outputs],
  ["Layers", :num_layers]
].each do |label, reader|
  puts("\t#{label}: #{net.public_send(reader)}")
end
net.layers.each_with_index do |layer, index|
  puts("\t\t#{index}\t#{layer.num_inputs}\t#{layer.size}\t#{layer.class}")
end

$stdout.flush
|
212
|
+
|
213
|
+
# Training pass: runs only when -t/--train was given. The model file, if any,
# is backed up first and rewritten after every batch the trainer reports.
if options.train
  backup(options.model_path) if options.model_path

  data_r = MNist::DataStream::Rotator.new(data, options.rotations, max_rad, false)
  data_t = MNist::DataStream::Translator.new(data_r, options.num_translations, options.translate_dx, options.translate_dy, false)
  training_set = MNist::TrainingSet.new(data_t).each

  # Number of augmented examples to train on. The limit applied to the stream
  # and the number printed below now come from the same value; previously the
  # limit left out num_translations, so fewer examples than announced were
  # used whenever translations were enabled.
  # NOTE(review): presumes the translator yields num_translations items per
  # rotated example - confirm against MNist::DataStream::Translator.
  nex = options.examples * options.rotations * options.num_translations

  ts = training_set.each
  ts = ts.first(nex) if options.examples > 0
  ts = ts.cycle(options.epochs) if options.epochs > 1

  nex = "all" if nex == 0
  puts("Training #{nex} examples in #{trainer_options.batch_size} sized batches at a rate of #{trainer_options.learning_rate} with #{trainer.name}.")

  trainer.train({ network: net, data: ts }.merge(trainer_options.to_h)) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?
    puts("Cost\t#{avg_err.average}")
    puts("  Magnitude\t#{avg_err.magnitude}")

    if options.model_path
      puts("Batch #{stats.batch} took #{stats.total_time} seconds")
      puts("Saving to #{options.model_path}")
      if options.binary_blob
        File.open(options.model_path, 'wb') do |f|
          f.write(Marshal.dump(net))
        end
      else
        net.save(options.model_path)
      end
    end

    $stdout.flush
  end
end
|
257
|
+
|
258
|
+
CHECKMARK = "\u2714"
CROSSMARK = "\u2718"

# Evaluation pass: feeds examples read from the configured test files through
# the network and reports, per example, whether the strongest output matches
# the label.
puts("Trying the test images")
tested = 0   # examples actually seen; may be fewer than options.num_tests
failures = 0 # examples whose strongest output did not match the label
data = MNist::DataStream.new(options.test_labels_path, options.test_images_path)
samples = data.each.drop(options.start_tests_at).first(options.num_tests)
data_r = MNist::DataStream::Rotator.new(samples, 1, max_rad, true)
data_t = MNist::DataStream::Translator.new(data_r, 1, options.translate_dx, options.translate_dy, true)
data_t.each_with_index do |example, i|
  output, _hidden_state = net.predict(CooCoo::Vector[example.pixels, data.width * data.height, 0] / 256.0, Hash.new, true)
  # [value, index] pairs ordered from the strongest output to the weakest.
  max_outputs = output.each_with_index.sort.reverse
  max_output = max_outputs.first[1]
  passed = example.label == max_output
  color = passed ? :green : :red
  mark = passed ? CHECKMARK : CROSSMARK
  tested += 1
  failures += 1 unless passed

  sixel = if options.sixel
            picture = CooCoo::Drawing::Sixel.to_string do |s|
              16.times { |shade| c = shade / 16.0 * 100; s.set_color(shade, c, c, c) }
              s.from_array(CooCoo::Vector[example.each_pixel.collect.to_a] * 16.0 / 256.0, 28, 28)
            end
            " for " + picture
          else
            "\n"
          end

  puts("#{mark.send(color)} #{i.to_s.send(color)}\tExpecting: #{example.label}#{sixel}\tAngle: #{example.angle * 180.0 / Math::PI}\n\tOffset: #{example.offset_x} #{example.offset_y}\n\tGot: #{max_output}\t#{max_output == example.label}\n\tOutputs: #{output}\n\tBest guesses: #{max_outputs.first(3).inspect}")
  puts("#{example.to_ascii}") unless passed
end

# Divide by the number of examples that were really tested so a short stream
# does not understate the error rate, and avoid NaN when nothing was tested.
error_rate = tested.zero? ? 0.0 : failures / tested.to_f * 100.0
puts("Errors: #{error_rate}% (#{failures}/#{tested})")
|
@@ -0,0 +1,214 @@
|
|
1
|
+
require 'coo-coo'
|
2
|
+
require 'ostruct'
|
3
|
+
require 'coo-coo/drawing/sixel'
|
4
|
+
require 'colorize'
|
5
|
+
|
6
|
+
# When true, output_to_ascii colorizes every character it emits.
$use_color = true
# Characters and colors used for the ASCII rendering; a larger pixel value
# selects an entry further to the right.
PixelValues = ' -+%X#'
ColorValues = [ :black, :red, :green, :blue, :magenta, :white ]

# Maps +p+ onto an index into a palette of +size+ entries. +p+ is clamped to
# 0.0..1.0 first: without the clamp a negative value produced a negative
# index, which wraps around to the far end of the palette, and a value above
# 1.0 fell back to the first entry.
def palette_index(p, size)
  (p.clamp(0.0, 1.0) * (size - 1)).to_i
end

# Returns the character that represents pixel value +p+.
#
# @param p [Numeric] pixel value, expected in 0.0..1.0 (clamped otherwise)
# @return [String] a single character taken from PixelValues
def char_for_pixel(p)
  PixelValues[palette_index(p, PixelValues.length)]
end

# Returns the color that represents pixel value +p+.
#
# @param p [Numeric] pixel value, expected in 0.0..1.0 (clamped otherwise)
# @return [Symbol] an entry of ColorValues
def color_for_pixel(p)
  ColorValues[palette_index(p, ColorValues.length)]
end
|
17
|
+
|
18
|
+
# Renders +output+ as a square block of characters, one text line per row.
# The values are first passed through +minmax_normalize(true)+; the side
# length is the integer square root of the element count, so any elements
# beyond side * side are not drawn.
#
# @param output [#minmax_normalize] values to draw
# @return [String] the rows, each terminated by a newline
def output_to_ascii(output)
  output = output.minmax_normalize(true)
  side = Math.sqrt(output.size).to_i

  side.times.each_with_object(+'') do |row, text|
    side.times do |column|
      raw = output[column + row * side]
      # Keep the value inside 0.0..1.0 before looking up its glyph and color.
      level = if raw > 1.0
                1.0
              elsif raw < 0.0
                0.0
              else
                raw
              end
      glyph = char_for_pixel(level)
      glyph = glyph.colorize(color_for_pixel(level)) if $use_color
      text << glyph
    end
    text << "\n"
  end
end
|
36
|
+
|
37
|
+
# Renders +output+ as a square Sixel graphic using a 16-step gray palette.
# The values are first passed through +minmax_normalize(true)+.
#
# @param output [#minmax_normalize] values to draw
# @return whatever CooCoo::Drawing::Sixel.to_string returns for the drawing
def output_to_sixel(output)
  normalized = output.minmax_normalize(true)

  CooCoo::Drawing::Sixel.to_string do |sixel|
    # Palette entries 0..15 hold grays from 0 up to 93.75 percent.
    (0...16).each do |index|
      gray = index / 16.0 * 100
      sixel.set_color(index, gray, gray, gray)
    end
    side = Math.sqrt(normalized.size).to_i
    # NOTE(review): a normalized value of 1.0 scales to 16, one past the last
    # palette entry defined above - confirm how from_array treats it.
    sixel.from_array(normalized * 16, side, side)
  end
end
|
46
|
+
|
47
|
+
# Generic iterative descent loop driven entirely by callables.
#
# Each epoch calls, in order: +f+ (no arguments), +cost+ with the splatted
# result of +f+, +loss+ with the cost followed by that same splatted result,
# and +update+ with the loss result multiplied by +rate+ plus the previous
# epoch's deltas multiplied by +rate+ (0.0 on the first epoch).
#
# @param opts [Hash] required keys :f, :cost, :loss, :update, :status, :rate;
#   optional :epochs (default 1), :verbose (default false) and :status_time
#   (seconds between progress reports, default infinity)
# @raise [KeyError] when a required key is missing
# @return the result of the final +status+ call, or nil when +status+ or
#   +verbose+ is falsy
def sgd(opts)
  f, cost, loss, update, status = %i[f cost loss update status].map { |key| opts.fetch(key) }
  #on_batch = opts.fetch(:on_batch)
  epochs = opts.fetch(:epochs, 1)
  rate = opts.fetch(:rate)
  verbose = opts.fetch(:verbose, false)
  status_time = opts.fetch(:status_time, Float::INFINITY)

  reporting = status && verbose
  last_report = Time.now
  previous_deltas = 0.0 # CooCoo::Vector.zeros(28 * 28)
  c = output = deltas = nil

  epochs.times do |e|
    output = f.call
    c = cost.call(*output)
    deltas = loss.call(c, *output) * rate
    update.call(deltas, previous_deltas * rate)
    previous_deltas = deltas

    elapsed = Time.now - last_report
    next unless reporting && elapsed > status_time

    status.call({ dt: elapsed, epoch: e, output: output, cost: c, deltas: deltas })
    last_report = Time.now
  end

  return unless reporting

  # Final report; +epoch+ is the total epoch count here rather than an index.
  status.call({ dt: Time.now - last_report, epoch: epochs, output: output, cost: c, deltas: deltas })
end
|
92
|
+
|
93
|
+
# Repeatedly adjusts an input vector, via +sgd+, using the errors that
# +net+ transfers back to its input for a target that is 1.0 at index
# +digit % net.num_outputs+ and zero elsewhere.
#
# @param loops [Integer] number of sgd epochs
# @param rate [Numeric] step size handed to sgd
# @param net network providing forward, backprop, transfer_errors,
#   prep_output_target, num_inputs and num_outputs
# @param digit [Integer] output index to maximize
# @param initial_input starting input; zeros of net.num_inputs when nil
# @param verbose [Boolean] enables the progress reports
# @param status_delay [Numeric] seconds between progress reports
# @param to_ascii [Boolean] include the ASCII rendering in progress reports
# @param to_sixel [Boolean] include the Sixel rendering in progress reports
# @return the adjusted input
def backprop_digit(loops, rate, net, digit, initial_input = nil, verbose = false, status_delay = 5.0, to_ascii = true, to_sixel = false)
  input = initial_input || CooCoo::Vector.zeros(net.num_inputs)
  target = CooCoo::Vector.zeros(net.num_outputs)
  target[digit % net.num_outputs] = 1.0
  target = net.prep_output_target(target)

  run_forward = -> { net.forward(input, {}, true, true) }

  # Difference between the last entry of the forward outputs and the target.
  output_error = ->(output, _hs) { output.last - target }

  # First entry of the errors transferred back through the network.
  input_error = lambda do |c, output, hs|
    deltas, = net.backprop(input, output, c, hs)
    net.transfer_errors(deltas).first
  end

  # Rebinds the closed-over +input+, which is what the next forward pass reads.
  move_input = ->(deltas, last_deltas) { input = input - deltas + last_deltas }

  report = lambda do |info|
    puts("#{info[:epoch]} #{digit} Input")
    puts(output_to_sixel(input)) if to_sixel
    puts(output_to_ascii(input)) if to_ascii
    puts("Output: #{info[:output][0].last[digit]}\t#{info[:output][0].last}\n")
    puts("Cost: #{info[:cost].magnitude}\t#{info[:cost]}\n")
    puts
  end

  sgd(epochs: loops,
      rate: rate,
      status_time: status_delay,
      verbose: verbose,
      f: run_forward,
      cost: output_error,
      loss: input_error,
      update: move_input,
      status: report)

  input
end
|
126
|
+
|
127
|
+
# Default settings for the dream script; the switches declared below
# overwrite individual entries.
options = OpenStruct.new(
  model_path: nil,
  loops: 10,
  rate: 0.5,
  initial_input: CooCoo::Vector.zeros(28 * 28),
  status_delay: 5.0,
  ascii: true,
  sixel: false
)
|
135
|
+
|
136
|
+
# Interprets a BOOL switch argument: "1", "t", "true", "y" and "yes" (in any
# letter case) mean true, anything else means false. The previous pattern was
# unanchored and listed "f", so "false" - and any text merely containing one
# of the letters - switched the option on, and the stored value was a match
# index or nil rather than a boolean.
truthy = ->(text) { text.match?(/\A(?:1|t(?:rue)?|y(?:es)?)\z/i) }

# Command-line switches for the dream script. Every handler stores its value
# in +options+, except --color which sets the $use_color global.
opts = CooCoo::OptionParser.new do |o|
  o.on('--print-values BOOL') { |bool| options.print_values = truthy.call(bool) }

  o.on('--sixel', "toggles on the display of the dream as a Sixel graphic") do
    options.sixel = !options.sixel
  end

  o.on('--ascii', "toggles off the display of the dream as ASCII") do
    options.ascii = !options.ascii
  end

  o.on('--color BOOL', 'toggles the use of color in the ASCII dream') do |bool|
    $use_color = truthy.call(bool)
  end

  o.on('-m', '--model PATH') { |path| options.model_path = Pathname.new(path) }
  o.on('-l', '--loops NUMBER') { |n| options.loops = n.to_i }
  o.on('-r', '--rate NUMBER') { |n| options.rate = n.to_f }
  o.on('-v', '--verbose') { options.verbose = true }
  o.on('--status-delay SECONDS') { |n| options.status_delay = n.to_f }

  # Only the first letter of NAME is inspected: o(nes), r(andom), z(eros)
  # or h (every element 0.5).
  o.on('-i', '--initial NAME') do |n|
    options.initial_input = case n[0].downcase
                            when 'o' then CooCoo::Vector.ones(28 * 28)
                            when 'r' then CooCoo::Vector.rand(28 * 28)
                            when 'z' then CooCoo::Vector.zeros(28 * 28)
                            when 'h' then CooCoo::Vector.new(28 * 28, 0.5)
                            else raise ArgumentError, "Unknown initial value #{n}"
                            end
  end
end
|
183
|
+
|
184
|
+
argv = opts.parse!(ARGV)

# Without -m the path is nil and File.extname would fail with an unhelpful
# TypeError, so the missing argument is reported explicitly.
raise ArgumentError, 'No model given; pass one with -m PATH.' unless options.model_path

net = if File.extname(options.model_path) == '.bin'
        # Marshal data is binary, so it is read in binary mode.
        # NOTE(review): Marshal.load can instantiate arbitrary objects; only
        # load model files from a trusted source.
        Marshal.load(File.binread(options.model_path))
      else
        CooCoo::Network.load(options.model_path)
      end

# With no digits on the command line, dream 0 through 9.
argv = 10.times if argv.empty?

# Generate every dream first, then print the results together.
dreams = argv.collect do |digit|
  digit = digit.to_i
  $stdout.puts("Generating #{digit}") if options.verbose
  input = backprop_digit(options.loops, options.rate, net, digit, options.initial_input, options.verbose, options.status_delay, options.ascii, options.sixel)
  $stdout.flush
  [ digit, input ]
end

dreams.each do |digit, input|
  output, _hs = net.predict(input, {})
  # A dream passes when the network scores it above 0.8 for the digit.
  passed = output[digit] > 0.8
  color = passed ? :green : :red
  status_char = passed ? "\u2714" : "\u2718"

  puts("#{digit}".colorize(color))
  puts('=' * 8)
  puts
  puts(output_to_sixel(input)) if options.sixel
  puts(output_to_ascii(input)) if options.ascii
  puts(input) if options.print_values
  puts
  puts("#{status_char.colorize(color)} Output #{output[digit]} #{output.magnitude} #{options.verbose ? output.inspect : ''}")
  puts
end
|