CooCoo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,293 @@
1
#!/usr/bin/env ruby
# Trains and/or evaluates a CooCoo network on the MNIST handwritten digits.

require 'fileutils'
require 'pathname'
require 'mnist'
require 'ostruct'
require 'coo-coo'
require 'coo-coo/image'
require 'coo-coo/convolution'
require 'coo-coo/neuron_layer'
# NOTE: 'coo-coo/subnet' used to be required here. No such file ships with
# this version of the gem, and nothing in this script uses it.
require 'coo-coo/drawing/sixel'
require 'colorize'
13
+
14
# Copies +path+ to a sibling "<path>~" file, replacing any earlier backup.
# Does nothing when +path+ does not exist.
#
# @param path [String, Pathname] file to back up
# @return [nil]
def backup(path)
  # File.exists? was removed in Ruby 3.2; File.exist? is the supported name.
  return unless File.exist?(path)

  backup_path = "#{path}~"
  File.delete(backup_path) if File.exist?(backup_path)
  FileUtils.copy(path, backup_path)
  nil
end
23
+
24
# Default settings; the command-line parser below overwrites them.
options = OpenStruct.new(
  examples: 0,
  epochs: 1,
  num_tests: 10,
  start_tests_at: 0,
  rotations: 8,
  max_rotation: 90.0,
  num_translations: 1,
  translate_dx: 0,
  translate_dy: 0,
  hidden_layers: nil,
  hidden_size: 128,
  activation_function: CooCoo.default_activation,
  trainer: 'Stochastic',
  softmax: false,
  convolution: nil,
  conv_step: 8,
  stacked_convolution: false,
  test_images_path: MNist::TEST_IMAGES_PATH,
  test_labels_path: MNist::TEST_LABELS_PATH
)
44
+
45
# Command-line interface. Each flag stores its value on +options+; --help also
# prints the selected trainer's own options.
opts = CooCoo::OptionParser.new do |o|
  o.on('-h', '--help') do
    puts(o)
    if options.trainer
      t = CooCoo::Trainer.from_name(options.trainer)
      raise NameError, "Unknown trainer #{options.trainer}" unless t
      # Use a distinct local: assigning to `opts` here would overwrite the
      # parser variable that this very block is building.
      trainer_parser, _trainer_options = t.options
      puts(trainer_parser)
    end
    exit
  end

  o.on('--sixel') do
    options.sixel = true
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    # Only switch binary mode on. An earlier --binary must not be undone just
    # because the file name lacks a .bin extension.
    options.binary_blob = true if File.extname(options.model_path) == '.bin'
  end

  o.on('--binary') do
    options.binary_blob = true
  end

  o.on('-t', '--train NUMBER', 'train for number of epochs') do |n|
    options.train = true
    options.epochs = n.to_i
  end

  o.on('-e', '--examples NUMBER') do |n|
    options.examples = n.to_i
  end

  o.on('-p', '--predict NUMBER') do |n|
    options.num_tests = n.to_i
  end

  o.on('-s', '--skip NUMBER') do |n|
    options.start_tests_at = n.to_i
  end

  o.on('-r', '--rotations NUMBER') do |n|
    options.rotations = n.to_i
  end

  o.on('-a', '--angle NUMBER') do |n|
    options.max_rotation = n.to_f
  end

  o.on('--num-translations NUMBER') do |n|
    options.num_translations = n.to_i
  end

  o.on('--delta-x NUMBER') do |dx|
    options.translate_dx = dx.to_f
  end

  o.on('--delta-y NUMBER') do |dy|
    options.translate_dy = dy.to_f
  end

  o.on('-l', '--hidden-layers NUMBER') do |n|
    options.hidden_layers = n.to_i
  end

  o.on('--hidden-size NUMBER') do |n|
    options.hidden_size = n.to_i
  end

  o.on('-f', '--activation-func FUNC') do |func|
    options.activation_function = CooCoo::ActivationFunctions.from_name(func)
  end

  o.on('--trainer NAME') do |name|
    options.trainer = name
  end

  o.on('--softmax') do
    options.softmax = true
  end

  o.on('--convolution') do
    options.convolution = true
  end

  o.on('--convolution-step NUMBER') do |n|
    step = n.to_i
    raise ArgumentError, "The convolution step must be >0." if step <= 0
    options.conv_step = step
  end
end
137
+
138
argv = opts.parse!(ARGV)
# --angle is given in degrees; the rotating data stream is given radians.
max_rad = options.max_rotation.to_f * Math::PI / 180.0

trainer = nil
trainer_options = nil
if options.trainer
  trainer = CooCoo::Trainer.from_name(options.trainer)
  raise NameError, "Unknown trainer #{options.trainer}" unless trainer
  # The trainer brings its own option parser; let it consume whatever
  # arguments the main parser left over.
  t_opts, trainer_options = trainer.options
  argv = t_opts.parse!(argv)
end

# The >= 8 minimum only applies when stacking convolutions, as the message
# states. --convolution-step itself accepts any step > 0, so an unconditional
# check rejected steps that option explicitly allows.
if options.stacked_convolution && options.conv_step < 8
  raise ArgumentError, "The convolution step must be >=8 when stacking convolutions."
end
152
+
153
puts("Loading MNist data")
data = MNist::DataStream.new

net = CooCoo::Network.new

# File.exists? was removed in Ruby 3.2; File.exist? is the supported name.
if options.model_path && File.exist?(options.model_path)
  puts("Loading #{options.model_path}")
  if options.binary_blob
    # Read in binary mode so no newline/encoding translation corrupts the
    # Marshal stream (it is written in 'wb' mode when saving).
    # NOTE(review): Marshal.load can instantiate arbitrary objects; only load
    # model files from trusted sources.
    net = Marshal.load(File.binread(options.model_path))
  else
    net.load!(options.model_path)
  end
else
  # No saved model: build a fresh network sized for the image data.
  area = data.width * data.height

  if options.convolution
    l = CooCoo::Convolution::BoxLayer.new(data.width, data.height, options.conv_step, options.conv_step, CooCoo::Layer.new(16, 4, options.activation_function), 4, 4, 2, 2)
    net.layer(l)
    # The fully connected layers below take the convolution's output.
    area = l.size
  end

  if options.hidden_layers
    # NOTE(review): -l 1 and -l 2 build the same two-layer net, while -l N > 2
    # builds N layers. Confirm whether N is meant to count hidden layers.
    net.layer(CooCoo::Layer.new(area, options.hidden_size, options.activation_function))
    # Integer#times yields nothing for zero or negative counts.
    (options.hidden_layers - 2).times do
      net.layer(CooCoo::Layer.new(options.hidden_size, options.hidden_size, options.activation_function))
    end
    net.layer(CooCoo::Layer.new(options.hidden_size, 10, options.activation_function))
  else
    net.layer(CooCoo::Layer.new(area, area / 4, options.activation_function))
    net.layer(CooCoo::Layer.new(area / 4, 10, options.activation_function))
  end

  if options.softmax
    net.layer(CooCoo::LinearLayer.new(10, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
  end
end
200
+
201
# Describe the network that will be trained/evaluated, one fact per line.
puts("Net ready:")
[
  ['Age', :age],
  ['Activation', :activation_function],
  ['Inputs', :num_inputs],
  ['Outputs', :num_outputs],
  ['Layers', :num_layers]
].each do |label, attribute|
  puts("\t#{label}: #{net.public_send(attribute)}")
end
net.layers.each.with_index do |layer, index|
  puts("\t\t#{index}\t#{layer.num_inputs}\t#{layer.size}\t#{layer.class}")
end

$stdout.flush
212
+
213
# Train when --train was given, saving (and first backing up) the model after
# every batch if a model path was supplied.
if options.train
  backup(options.model_path) if options.model_path

  data_r = MNist::DataStream::Rotator.new(data, options.rotations, max_rad, false)
  data_t = MNist::DataStream::Translator.new(data_r, options.num_translations, options.translate_dx, options.translate_dy, false)
  ts = MNist::TrainingSet.new(data_t).each

  # The Rotator/Translator pipeline is configured for rotations *
  # num_translations variants per source image (the factor the printed count
  # already used). Scale --examples by the same factor so the requested
  # number of source images is actually trained on.
  variants_per_example = options.rotations * options.num_translations
  ts = ts.first(options.examples * variants_per_example) if options.examples > 0
  ts = ts.cycle(options.epochs) if options.epochs > 1

  nex = options.examples * variants_per_example
  nex = "all" if nex == 0
  puts("Training #{nex} examples in #{trainer_options.batch_size} sized batches at a rate of #{trainer_options.learning_rate} with #{trainer.name}.")

  trainer.train({ network: net, data: ts }.merge(trainer_options.to_h)) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?
    puts("Cost\t#{avg_err.average}")
    puts(" Magnitude\t#{avg_err.magnitude}")

    if options.model_path
      puts("Batch #{stats.batch} took #{stats.total_time} seconds")
      puts("Saving to #{options.model_path}")
      if options.binary_blob
        File.binwrite(options.model_path, Marshal.dump(net))
      else
        net.save(options.model_path)
      end
    end

    $stdout.flush
  end
end
257
+
258
CHECKMARK = "\u2714"
CROSSMARK = "\u2718"

# Evaluate the network on the held-out test set (--skip/--predict select the
# slice) and report each prediction plus the overall error rate.
puts("Trying the test images")
data = MNist::DataStream.new(options.test_labels_path, options.test_images_path)
test_examples = data.each.drop(options.start_tests_at).first(options.num_tests)
data_r = MNist::DataStream::Rotator.new(test_examples, 1, max_rad, true)
data_t = MNist::DataStream::Translator.new(data_r, 1, options.translate_dx, options.translate_dy, true)

tested = 0
failures = 0
data_t.each_with_index do |example, i|
  tested += 1
  output, _hidden_state = net.predict(CooCoo::Vector[example.pixels, data.width * data.height, 0] / 256.0, {}, true)
  # [value, index] pairs, highest output first; the top index is the guess.
  max_outputs = output.each_with_index.sort.reverse
  max_output = max_outputs.first[1]
  passed = example.label == max_output
  failures += 1 unless passed
  color = passed ? :green : :red
  mark = passed ? CHECKMARK : CROSSMARK

  sixel = if options.sixel
            image = CooCoo::Drawing::Sixel.to_string do |s|
              16.times { |shade| c = shade / 16.0 * 100; s.set_color(shade, c, c, c) }
              s.from_array(CooCoo::Vector[example.each_pixel.collect.to_a] * 16.0 / 256.0, data.width, data.height)
            end
            " for " + image
          else
            "\n"
          end

  puts("#{mark.send(color)} #{i.to_s.send(color)}\tExpecting: #{example.label}#{sixel}\tAngle: #{example.angle * 180.0 / Math::PI}\n\tOffset: #{example.offset_x} #{example.offset_y}\n\tGot: #{max_output}\t#{passed}\n\tOutputs: #{output}\n\tBest guesses: #{max_outputs.first(3).inspect}")
  puts("#{example.to_ascii}") unless passed
end

# Divide by the number of images actually evaluated (the data set may hold
# fewer than --predict asked for), and avoid 0/0 when nothing was tested.
error_rate = tested.zero? ? 0.0 : failures * 100.0 / tested
puts("Errors: #{error_rate}% (#{failures}/#{tested})")
@@ -0,0 +1,214 @@
1
+ require 'coo-coo'
2
+ require 'ostruct'
3
+ require 'coo-coo/drawing/sixel'
4
+ require 'colorize'
5
+
6
# When true, output_to_ascii colours each character (--color sets it).
$use_color = true
# Characters used to draw a pixel, from darkest (index 0) to brightest.
PixelValues = ' -+%X#'.freeze
# Colour names, matched to PixelValues intensity levels by position.
ColorValues = %i[black red green blue magenta white].freeze

# Maps an intensity to a drawing character.
#
# @param p [Numeric] intensity, nominally in [0, 1]; values outside that range
#   are clamped (negative values used to wrap around to the bright end).
# @return [String] one character from PixelValues
def char_for_pixel(p)
  PixelValues[(p.clamp(0.0, 1.0) * (PixelValues.length - 1)).to_i]
end

# Maps an intensity to a colour name.
#
# @param p [Numeric] intensity, nominally in [0, 1]; clamped like char_for_pixel
# @return [Symbol] one entry of ColorValues
def color_for_pixel(p)
  ColorValues[(p.clamp(0.0, 1.0) * (ColorValues.length - 1)).to_i]
end
17
+
18
# Renders a square, row-major vector as ASCII art, one character per element.
#
# The vector is first normalised with its own +minmax_normalize(true)+. Each
# value is then clamped to [0, 1] and mapped through char_for_pixel, and
# coloured with color_for_pixel when $use_color is set.
#
# @param output [CooCoo::Vector] only the leading w*w elements are drawn, where
#   w = floor(sqrt(size))
# @return [String] w lines, each terminated by "\n"
def output_to_ascii(output)
  normalized = output.minmax_normalize(true)
  width = Math.sqrt(normalized.size).to_i

  lines = Array.new(width) do |row|
    cells = Array.new(width) do |col|
      value = normalized[col + row * width]
      value = 1.0 if value > 1.0
      value = 0.0 if value < 0.0
      glyph = char_for_pixel(value)
      $use_color ? glyph.colorize(color_for_pixel(value)) : glyph
    end
    cells.join + "\n"
  end
  lines.join
end
36
+
37
# Renders a square, row-major vector as a Sixel image string.
#
# The vector is normalised with +minmax_normalize(true)+, a 16-entry grey
# palette (0% .. 93.75%) is registered, and the values scaled by 16 are handed
# to the Sixel writer.
#
# @param output [CooCoo::Vector] drawn as a floor(sqrt(size)) square
# @return [String] Sixel escape sequence produced by CooCoo::Drawing::Sixel
def output_to_sixel(output)
  normalized = output.minmax_normalize(true)
  side = Math.sqrt(normalized.size).to_i

  CooCoo::Drawing::Sixel.to_string do |sixel|
    16.times do |level|
      grey = level / 16.0 * 100
      sixel.set_color(level, grey, grey, grey)
    end
    # NOTE(review): a normalised value of 1.0 scales to 16, one past the last
    # palette entry (15); confirm from_array clamps/floors such indices.
    sixel.from_array(normalized * 16, side, side)
  end
end
46
+
47
# Minimal gradient-descent driver assembled from caller-supplied callables.
#
# Each epoch:
# - +f+ produces an output array, which is splatted into +cost+ and +loss+.
# - The deltas returned by +loss+ are multiplied by +rate+.
# - +update+ receives those deltas plus the previous epoch's deltas
#   multiplied by +rate+ again.
#
# @param opts [Hash]
#   :f           [#call] no-argument callable returning an array
#   :cost        [#call] called with the splatted output; returns the cost
#   :loss        [#call] called with (cost, *output); returns raw deltas
#   :update      [#call] called with (deltas, previous_deltas * rate)
#   :status      [#call, nil] receives a progress Hash (key required, may be nil)
#   :epochs      [Integer] iterations to run (default 1)
#   :rate        [Numeric] learning rate (required)
#   :verbose     [Boolean] enables status reports (default false)
#   :status_time [Numeric] minimum seconds between in-loop reports (default:
#                never; only the final report is made)
# @raise [KeyError] when a required option is missing
def sgd(opts)
  f = opts.fetch(:f)
  cost = opts.fetch(:cost)
  loss = opts.fetch(:loss)
  update = opts.fetch(:update)
  status = opts.fetch(:status)
  epochs = opts.fetch(:epochs, 1)
  rate = opts.fetch(:rate)
  verbose = opts.fetch(:verbose, false)
  status_time = opts.fetch(:status_time, Float::INFINITY)
  report = status && verbose

  # Time intervals on a monotonic clock: wall-clock time can jump (NTP, manual
  # changes) and yield negative or inflated durations.
  clock = -> { Process.clock_gettime(Process::CLOCK_MONOTONIC) }
  last_time = clock.call
  last_deltas = 0.0
  c = nil
  output = nil
  deltas = nil

  epochs.times do |e|
    output = f.call
    c = cost.call(*output)
    deltas = loss.call(c, *output) * rate
    update.call(deltas, last_deltas * rate)
    last_deltas = deltas

    dt = clock.call - last_time
    next unless report && dt > status_time

    status.call({ dt: dt, epoch: e, output: output, cost: c, deltas: deltas })
    last_time = clock.call
  end

  return unless report

  status.call({ dt: clock.call - last_time, epoch: epochs, output: output, cost: c, deltas: deltas })
end
92
+
93
# "Dreams" an input that +net+ classifies as +digit+. The network is left
# unchanged; the backpropagated error is applied to the input vector, driven
# by #sgd.
#
# @param loops [Integer] number of descent iterations
# @param rate [Float] step size passed to #sgd
# @param net [CooCoo::Network] trained classifier
# @param digit [Integer] class to dream; wrapped modulo net.num_outputs
# @param initial_input [CooCoo::Vector, nil] starting image (zeros when nil)
# @param verbose [Boolean] print status reports
# @param status_delay [Float] minimum seconds between in-loop status reports
# @param to_ascii [Boolean] include an ASCII rendering in status reports
# @param to_sixel [Boolean] include a Sixel rendering in status reports
# @return [CooCoo::Vector] the optimised input
def backprop_digit(loops, rate, net, digit, initial_input = nil, verbose = false, status_delay = 5.0, to_ascii = true, to_sixel = false)
  initial_input ||= CooCoo::Vector.zeros(net.num_inputs)
  # The target and the status report must refer to the same output slot.
  target_index = digit % net.num_outputs
  input = initial_input
  target = CooCoo::Vector.zeros(net.num_outputs)
  target[target_index] = 1.0
  target = net.prep_output_target(target)

  # The lambdas close over +input+. +update+ reassigns it, so each epoch's
  # forward and backward passes see the latest image.
  sgd(epochs: loops, rate: rate, status_time: status_delay, verbose: verbose,
      f: lambda do
        net.forward(input, {}, true, true)
      end,
      cost: lambda do |output, _hs|
        output.last - target
      end,
      loss: lambda do |c, output, hs|
        deltas, _hs = net.backprop(input, output, c, hs)
        # Only the errors for the network's input are needed.
        net.transfer_errors(deltas).first
      end,
      update: lambda do |deltas, last_deltas|
        # NOTE(review): this adds the previous (rate-scaled) step back in; a
        # classic momentum term would subtract it. Confirm the intended sign.
        input = input - deltas + last_deltas
      end,
      status: lambda do |opts|
        puts("#{opts[:epoch]} #{digit} Input")
        puts(output_to_sixel(input)) if to_sixel
        puts(output_to_ascii(input)) if to_ascii
        puts("Output: #{opts[:output][0].last[target_index]}\t#{opts[:output][0].last}\n")
        puts("Cost: #{opts[:cost].magnitude}\t#{opts[:cost]}\n")
        puts
      end)

  input
end
126
+
127
# Default settings; the command-line parser below overwrites them.
options = OpenStruct.new(
  model_path: nil,
  loops: 10,
  rate: 0.5,
  initial_input: CooCoo::Vector.zeros(28 * 28),
  status_delay: 5.0,
  ascii: true,
  sixel: false
)
+
136
# Parses a command-line boolean. Only "1", "t", "true", "y" and "yes" (case
# insensitive) are true. The previous unanchored
# /(1|t(rue)?|f(false)?|y(es)?)/ matched the "f" in "false" and "off", so
# those values switched the feature on.
truthy_flag = ->(text) { /\A(?:1|t(?:rue)?|y(?:es)?)\z/i.match?(text) }

# Command-line interface; each flag stores its value on +options+.
opts = CooCoo::OptionParser.new do |o|
  o.on('--print-values BOOL') do |bool|
    options.print_values = truthy_flag.call(bool)
  end

  o.on('--sixel', "toggles on the display of the dream as a Sixel graphic") do
    options.sixel = !options.sixel
  end

  o.on('--ascii', "toggles off the display of the dream as ASCII") do
    options.ascii = !options.ascii
  end

  o.on('--color BOOL', 'toggles the use of color in the ASCII dream') do |bool|
    $use_color = truthy_flag.call(bool)
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
  end

  o.on('-l', '--loops NUMBER') do |n|
    options.loops = n.to_i
  end

  o.on('-r', '--rate NUMBER') do |n|
    options.rate = n.to_f
  end

  o.on('-v', '--verbose') do
    options.verbose = true
  end

  o.on('--status-delay SECONDS') do |n|
    options.status_delay = n.to_f
  end

  # Starting image: ones, random, zeros, or half grey (first letter decides).
  o.on('-i', '--initial NAME') do |n|
    options.initial_input = case n[0].downcase
                            when 'o' then CooCoo::Vector.ones(28 * 28)
                            when 'r' then CooCoo::Vector.rand(28 * 28)
                            when 'z' then CooCoo::Vector.zeros(28 * 28)
                            when 'h' then CooCoo::Vector.new(28 * 28, 0.5)
                            else raise ArgumentError, "Unknown initial value #{n}"
                            end
  end
end
183
+
184
argv = opts.parse!(ARGV)
# Without a model File.extname(nil) raises an opaque TypeError; say what is missing.
raise ArgumentError, "A trained model is required: pass it with -m/--model PATH." unless options.model_path

net = if File.extname(options.model_path) == '.bin'
        # Marshal data is binary; read it without newline/encoding translation.
        # NOTE(review): Marshal.load can instantiate arbitrary objects; only
        # load model files from trusted sources.
        Marshal.load(File.binread(options.model_path))
      else
        CooCoo::Network.load(options.model_path)
      end

# With no digits on the command line, dream every digit 0-9.
argv = 10.times if argv.empty?
192
+
193
# Dream an input for every requested digit first, then score and print each.
dreams = argv.collect do |arg|
  digit = arg.to_i
  $stdout.puts("Generating #{digit}") if options.verbose
  input = backprop_digit(options.loops, options.rate, net, digit, options.initial_input, options.verbose, options.status_delay, options.ascii, options.sixel)
  $stdout.flush
  [digit, input]
end

dreams.each do |digit, input|
  output, _hs = net.predict(input, {})
  # backprop_digit optimised output slot digit % num_outputs; score that same
  # slot (an unwrapped index past the last output would not be comparable).
  index = digit % net.num_outputs
  passed = output[index] > 0.8
  color = passed ? :green : :red
  status_char = passed ? "\u2714" : "\u2718"

  puts("#{digit}".colorize(color))
  puts('=' * 8)
  puts
  puts(output_to_sixel(input)) if options.sixel
  puts(output_to_ascii(input)) if options.ascii
  puts(input) if options.print_values
  puts
  puts("#{status_char.colorize(color)} Output #{output[index]} #{output.magnitude} #{options.verbose ? output.inspect : ''}")
  puts
end