CooCoo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,293 @@
1
+ #!/bin/env ruby
2
+
3
+ require 'fileutils'
4
+ require 'mnist'
5
+ require 'ostruct'
6
+ require 'coo-coo'
7
+ require 'coo-coo/image'
8
+ require 'coo-coo/convolution'
9
+ require 'coo-coo/neuron_layer'
10
+ require 'coo-coo/subnet'
11
+ require 'coo-coo/drawing/sixel'
12
+ require 'colorize'
13
+
14
# Preserves the file at +path+ by copying it to "<path>~" before it is
# overwritten. Any previous backup is replaced. Does nothing when +path+
# does not exist.
#
# @param path [String, Pathname] file to back up
# @return [void]
def backup(path)
  # File.exists? was deprecated and removed in Ruby 3.2; File.exist? works everywhere.
  return unless File.exist?(path)

  backup_path = "#{path}~"
  # Remove any stale backup before copying the current file over.
  File.delete(backup_path) if File.exist?(backup_path)
  FileUtils.copy(path, backup_path)
end
23
+
24
# Default settings; every field can be overridden by the options parsed below.
options = OpenStruct.new
options.examples = 0
options.epochs = 1
options.num_tests = 10
options.start_tests_at = 0
options.rotations = 8
options.max_rotation = 90.0
options.num_translations = 1
options.translate_dx = 0
options.translate_dy = 0
options.hidden_layers = nil
options.hidden_size = 128
options.activation_function = CooCoo.default_activation
options.trainer = 'Stochastic'
options.softmax = false
options.convolution = nil
options.conv_step = 8
# No command-line option currently sets this.
options.stacked_convolution = false
options.test_images_path = MNist::TEST_IMAGES_PATH
options.test_labels_path = MNist::TEST_LABELS_PATH

opts = CooCoo::OptionParser.new do |o|
  o.on('-h', '--help') do
    puts(o)
    # Also list the selected trainer's own options. Options are presumably
    # handled in command-line order, so only a --trainer given before -h counts.
    if options.trainer
      trainer_class = CooCoo::Trainer.from_name(options.trainer)
      raise NameError, "Unknown trainer #{options.trainer}" unless trainer_class
      # Deliberately not named `opts`: that would reassign the outer parser
      # variable, which is visible inside this block.
      trainer_parser, _ = trainer_class.options
      puts(trainer_parser)
    end
    exit
  end

  o.on('--sixel') do
    options.sixel = true
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    # A .bin extension selects the Marshal format; --binary forces it.
    options.binary_blob = File.extname(options.model_path) == '.bin'
  end

  o.on('--binary') do
    options.binary_blob = true
  end

  o.on('-t', '--train NUMBER', 'train for number of epochs') do |n|
    options.train = true
    options.epochs = n.to_i
  end

  o.on('-e', '--examples NUMBER') do |n|
    options.examples = n.to_i
  end

  o.on('-p', '--predict NUMBER') do |n|
    options.num_tests = n.to_i
  end

  o.on('-s', '--skip NUMBER') do |n|
    options.start_tests_at = n.to_i
  end

  o.on('-r', '--rotations NUMBER') do |n|
    options.rotations = n.to_i
  end

  o.on('-a', '--angle NUMBER') do |n|
    options.max_rotation = n.to_f
  end

  o.on('--num-translations NUMBER') do |n|
    options.num_translations = n.to_i
  end

  o.on('--delta-x NUMBER') do |dx|
    options.translate_dx = dx.to_f
  end

  o.on('--delta-y NUMBER') do |dy|
    options.translate_dy = dy.to_f
  end

  o.on('-l', '--hidden-layers NUMBER') do |n|
    options.hidden_layers = n.to_i
  end

  o.on('--hidden-size NUMBER') do |n|
    options.hidden_size = n.to_i
  end

  o.on('-f', '--activation-func FUNC') do |func|
    options.activation_function = CooCoo::ActivationFunctions.from_name(func)
  end

  o.on('--trainer NAME') do |name|
    options.trainer = name
  end

  o.on('--softmax') do
    options.softmax = true
  end

  o.on('--convolution') do
    options.convolution = true
  end

  o.on('--convolution-step NUMBER') do |n|
    n = n.to_i
    raise ArgumentError, "The convolution step must be >0." if n <= 0
    options.conv_step = n
  end
end
137
+
138
argv = opts.parse!(ARGV)
# Maximum rotation converted from degrees to radians.
max_rad = options.max_rotation.to_f * Math::PI / 180.0

# Resolve the trainer and let its own parser consume the remaining arguments.
trainer = nil
trainer_options = nil
if options.trainer
  trainer = CooCoo::Trainer.from_name(options.trainer)
  raise NameError, "Unknown trainer #{options.trainer}" unless trainer
  t_opts, trainer_options = trainer.options
  argv = t_opts.parse!(argv)
end

# The >= 8 limit only applies when stacking convolutions; a plain
# convolution step just has to be positive (checked by --convolution-step).
# Previously this fired unconditionally and rejected any step below 8.
if options.stacked_convolution && options.conv_step < 8
  raise ArgumentError, "The convolution step must be >=8 when stacking convolutions."
end
152
+
153
puts("Loading MNist data")
data = MNist::DataStream.new

net = CooCoo::Network.new

# File.exists? was removed in Ruby 3.2; File.exist? is the supported spelling.
if options.model_path && File.exist?(options.model_path)
  # Resume from a previously saved model.
  puts("Loading #{options.model_path}")
  if options.binary_blob
    # Marshal dumps are binary (written with 'wb'), so read them in binary mode too.
    # NOTE(review): Marshal.load must only be given trusted model files.
    net = Marshal.load(File.binread(options.model_path))
  else
    net.load!(options.model_path)
  end
else
  # Build a fresh network sized for the input images.
  area = data.width * data.height

  if options.convolution
    l = CooCoo::Convolution::BoxLayer.new(data.width, data.height, options.conv_step, options.conv_step, CooCoo::Layer.new(16, 4, options.activation_function), 4, 4, 2, 2)
    net.layer(l)
    # Subsequent layers consume the convolution's output instead of raw pixels.
    area = l.size
  end

  if options.hidden_layers
    # hidden_layers counts the fully connected layers including the output
    # layer; any value <= 2 builds two layers.
    net.layer(CooCoo::Layer.new(area, options.hidden_size, options.activation_function))
    if options.hidden_layers > 2
      (options.hidden_layers - 2).times do
        net.layer(CooCoo::Layer.new(options.hidden_size, options.hidden_size, options.activation_function))
      end
    end
    net.layer(CooCoo::Layer.new(options.hidden_size, 10, options.activation_function))
  else
    # Default: one hidden layer a quarter of the input size, then 10 outputs.
    net.layer(CooCoo::Layer.new(area, area / 4, options.activation_function))
    net.layer(CooCoo::Layer.new(area / 4, 10, options.activation_function))
  end

  if options.softmax
    net.layer(CooCoo::LinearLayer.new(10, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
  end
end
200
+
201
# Summarize the network: overall properties first, then one line per layer
# giving its index, input count, size and class.
puts("Net ready:")
[
  ["Age", net.age],
  ["Activation", net.activation_function],
  ["Inputs", net.num_inputs],
  ["Outputs", net.num_outputs],
  ["Layers", net.num_layers]
].each { |label, value| puts("\t#{label}: #{value}") }
net.layers.each.with_index do |layer, index|
  puts("\t\t#{index}\t#{layer.num_inputs}\t#{layer.size}\t#{layer.class}")
end

$stdout.flush
212
+
213
if options.train
  # Keep a copy of the previous model, since training overwrites it.
  backup(options.model_path) if options.model_path

  data_r = MNist::DataStream::Rotator.new(data, options.rotations, max_rad, false)
  data_t = MNist::DataStream::Translator.new(data_r, options.num_translations, options.translate_dx, options.translate_dy, false)
  ts = MNist::TrainingSet.new(data_t).each

  # Presumably every source image yields rotations * num_translations
  # augmented examples. Limit the stream by the same count that is reported
  # below; previously only examples * rotations were taken, which is fewer
  # than reported whenever num_translations > 1.
  nex = options.examples * options.rotations * options.num_translations
  ts = ts.first(nex) if options.examples > 0
  ts = ts.cycle(options.epochs) if options.epochs > 1

  nex = "all" if nex == 0
  puts("Training #{nex} examples in #{trainer_options.batch_size} sized batches at a rate of #{trainer_options.learning_rate} with #{trainer.name}.")

  trainer.train({ network: net,
                  data: ts
                }.merge(trainer_options.to_h)) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?
    puts("Cost\t#{avg_err.average}")
    puts(" Magnitude\t#{avg_err.magnitude}")

    # Checkpoint after every batch when a model path was given.
    if options.model_path
      puts("Batch #{stats.batch} took #{stats.total_time} seconds")
      puts("Saving to #{options.model_path}")
      if options.binary_blob
        File.binwrite(options.model_path, Marshal.dump(net))
      else
        net.save(options.model_path)
      end
    end

    $stdout.flush
  end
end
257
+
258
CHECKMARK = "\u2714"
CROSSMARK = "\u2718"

# Evaluate the network on images from the test set (one rotation and one
# translation each) and report how many it misclassifies.
puts("Trying the test images")
data = MNist::DataStream.new(options.test_labels_path, options.test_images_path)
num_pixels = data.width * data.height
selected = data.each.drop(options.start_tests_at).first(options.num_tests)
data_r = MNist::DataStream::Rotator.new(selected, 1, max_rad, true)
data_t = MNist::DataStream::Translator.new(data_r, 1, options.translate_dx, options.translate_dy, true)

num_tested = 0
num_errors = 0
data_t.each_with_index do |example, i|
  num_tested += 1
  # Pixel values are scaled by 1/256 before being fed to the network.
  input = CooCoo::Vector[example.pixels, num_pixels, 0] / 256.0
  output, _hidden_state = net.predict(input, {}, true)
  # [activation, index] pairs with the strongest activation first; the
  # index of the winner is the predicted label.
  max_outputs = output.each_with_index.sort.reverse
  max_output = max_outputs.first[1]
  passed = example.label == max_output
  color = passed ? :green : :red
  mark = passed ? CHECKMARK : CROSSMARK
  num_errors += 1 unless passed

  sixel = if options.sixel
            image = CooCoo::Drawing::Sixel.to_string do |s|
              16.times { |shade| c = shade / 16.0 * 100; s.set_color(shade, c, c, c) }
              s.from_array(CooCoo::Vector[example.each_pixel.collect.to_a] * 16.0 / 256.0, 28, 28)
            end
            " for " + image
          else
            "\n"
          end

  puts("#{mark.send(color)} #{i.to_s.send(color)}\tExpecting: #{example.label}#{sixel}\tAngle: #{example.angle * 180.0 / Math::PI}\n\tOffset: #{example.offset_x} #{example.offset_y}\n\tGot: #{max_output}\t#{max_output == example.label}\n\tOutputs: #{output}\n\tBest guesses: #{max_outputs.first(3).inspect}")
  puts("#{example.to_ascii}") unless passed
end

# Report against the number of images actually tested. The test set may hold
# fewer than --predict asked for, and --predict 0 used to print NaN.
error_rate = num_tested.zero? ? 0.0 : num_errors * 100.0 / num_tested
puts("Errors: #{error_rate}% (#{num_errors}/#{num_tested})")
@@ -0,0 +1,214 @@
1
+ require 'coo-coo'
2
+ require 'ostruct'
3
+ require 'coo-coo/drawing/sixel'
4
+ require 'colorize'
5
+
6
# When true, output_to_ascii colorizes each glyph.
$use_color = true
# Glyphs and colors ordered from the lowest intensity (index 0) to the highest.
PixelValues = ' -+%X#'
ColorValues = %i[black red green blue magenta white]

# Maps an intensity onto an index into a palette of +size+ entries.
# Intensities are clamped to [0, 1], so out-of-range values saturate at the
# palette ends. Previously, negative values wrapped around via negative
# indices, and values above 1 fell back to entry 0 (blank). Non-finite
# values map to index 0.
def pixel_palette_index(p, size)
  return 0 unless p.finite?

  (p.clamp(0.0, 1.0) * (size - 1)).to_i
end

# @param p [Numeric] intensity, nominally in [0, 1]
# @return [String] single-character glyph from PixelValues
def char_for_pixel(p)
  PixelValues[pixel_palette_index(p, PixelValues.length)]
end

# @param p [Numeric] intensity, nominally in [0, 1]
# @return [Symbol] color name from ColorValues
def color_for_pixel(p)
  ColorValues[pixel_palette_index(p, ColorValues.length)]
end
17
+
18
# Renders a square vector as ASCII art, one text row per image row.
# The vector is min/max normalized first, each value is clamped to [0, 1],
# and glyphs are optionally colorized when $use_color is set.
#
# @param output vector whose size is a perfect square
# @return [String] rows separated (and terminated) by "\n"
def output_to_ascii(output)
  normalized = output.minmax_normalize(true)
  side = Math.sqrt(normalized.size).to_i

  rows = Array.new(side) do |row|
    cells = Array.new(side) do |col|
      value = normalized[col + row * side]
      value = 1.0 if value > 1.0
      value = 0.0 if value < 0.0
      glyph = char_for_pixel(value)
      $use_color ? glyph.colorize(color_for_pixel(value)) : glyph
    end
    cells.join + "\n"
  end

  rows.join
end
36
+
37
# Renders a square vector as a Sixel graphic string using a 16-level
# grayscale palette. The vector is min/max normalized and then scaled by 16
# to select palette entries.
#
# @param output vector whose size is a perfect square
# @return [String] Sixel escape sequence produced by CooCoo::Drawing::Sixel
def output_to_sixel(output)
  normalized = output.minmax_normalize(true)

  CooCoo::Drawing::Sixel.to_string do |sixel|
    16.times do |shade|
      level = shade / 16.0 * 100
      sixel.set_color(shade, level, level, level)
    end
    side = Math.sqrt(normalized.size).to_i
    sixel.from_array(normalized * 16, side, side)
  end
end
46
+
47
# Minimal gradient-descent driver built from caller-supplied lambdas.
#
# Each epoch runs:
#   output = f.call
#   c      = cost.call(*output)
#   deltas = loss.call(c, *output) * rate
#   update.call(deltas, last_deltas * rate)   # last_deltas starts at 0.0
#
# @param opts [Hash]
#   required: :f, :cost, :loss, :update, :rate
#   optional: :epochs (1), :verbose (false), :status_time (Float::INFINITY),
#             :status (nil) - lambda given { dt:, epoch:, output:, cost:, deltas: }.
#             It is only called when :verbose is set: whenever more than
#             :status_time seconds have passed since the last report, and
#             once more after the final epoch (with epoch: epochs).
# @return the final status call's result, or nil when no status is reported
# @raise [KeyError] when a required key is missing
def sgd(opts)
  f = opts.fetch(:f)
  cost = opts.fetch(:cost)
  loss = opts.fetch(:loss)
  update = opts.fetch(:update)
  # The callback is nil-checked below, so treat it as optional instead of
  # raising KeyError when the caller leaves it out.
  status = opts.fetch(:status, nil)
  epochs = opts.fetch(:epochs, 1)
  rate = opts.fetch(:rate)
  verbose = opts.fetch(:verbose, false)
  status_time = opts.fetch(:status_time, Float::INFINITY)
  report = status && verbose

  # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
  clock = -> { Process.clock_gettime(Process::CLOCK_MONOTONIC) }
  last_time = clock.call
  last_deltas = 0.0
  c = nil
  output = nil
  deltas = nil

  epochs.times do |e|
    output = f.call
    c = cost.call(*output)
    deltas = loss.call(c, *output) * rate
    update.call(deltas, last_deltas * rate)
    last_deltas = deltas
    dt = clock.call - last_time
    next unless report && dt > status_time

    status.call({ dt: dt, epoch: e, output: output, cost: c, deltas: deltas })
    last_time = clock.call
  end

  return unless report

  status.call({ dt: clock.call - last_time,
                epoch: epochs,
                output: output,
                cost: c,
                deltas: deltas })
end
92
+
93
# "Dreams" an input for +digit+. Starting from +initial_input+, it
# repeatedly backpropagates the error between the network's final output
# and a one-hot target back to the network's *input*, then steps the
# input with sgd.
#
# @param loops [Integer] number of sgd epochs
# @param rate [Float] learning rate forwarded to sgd
# @param net the network to dream with
# @param digit [Integer] target class; wrapped with % net.num_outputs
# @param initial_input starting input; zeros of net.num_inputs when nil
# @param verbose [Boolean] enables the status printout
# @param status_delay [Float] minimum seconds between status printouts
# @param to_ascii [Boolean] include an ASCII rendering in status output
# @param to_sixel [Boolean] include a Sixel rendering in status output
# @return the final input
def backprop_digit(loops, rate, net, digit, initial_input = nil, verbose = false, status_delay = 5.0, to_ascii = true, to_sixel = false)
  input = initial_input || CooCoo::Vector.zeros(net.num_inputs)

  one_hot = CooCoo::Vector.zeros(net.num_outputs)
  one_hot[digit % net.num_outputs] = 1.0
  target = net.prep_output_target(one_hot)

  # Returns net.forward's result unchanged; sgd splats it into the lambdas
  # below as (outputs, hidden_state).
  run_forward = lambda { net.forward(input, {}, true, true) }

  # Error of the last layer's output against the target.
  output_error = lambda { |outputs, _hidden| outputs.last - target }

  # Backpropagate the error and return the gradient for the network's input.
  input_gradient = lambda do |err, outputs, hidden|
    layer_deltas, _ = net.backprop(input, outputs, err, hidden)
    net.transfer_errors(layer_deltas).first
  end

  # Reassigns the captured `input`, which is what this method returns.
  step_input = lambda { |step, previous_step| input = input - step + previous_step }

  print_status = lambda do |info|
    puts("#{info[:epoch]} #{digit} Input")
    puts(output_to_sixel(input)) if to_sixel
    puts(output_to_ascii(input)) if to_ascii
    puts("Output: #{info[:output][0].last[digit]}\t#{info[:output][0].last}\n")
    puts("Cost: #{info[:cost].magnitude}\t#{info[:cost]}\n")
    puts
  end

  sgd(epochs: loops, rate: rate, status_time: status_delay, verbose: verbose,
      f: run_forward, cost: output_error, loss: input_gradient,
      update: step_input, status: print_status)

  input
end
126
+
127
# Defaults; each field may be overridden by the command-line options below.
options = OpenStruct.new
options.model_path = nil
options.loops = 10
options.rate = 0.5
options.initial_input = CooCoo::Vector.zeros(28 * 28)
options.status_delay = 5.0
options.ascii = true
options.sixel = false

# Interprets a yes/no option value: only 1, t, true, y or yes (any case)
# count as true. The previous unanchored /(1|t(rue)?|f(false)?|y(es)?)/ also
# matched "f"/"false"/"off" and any word containing a t or y, so e.g.
# "--color false" turned color on.
truthy = lambda do |value|
  value.match?(/\A(?:1|t(?:rue)?|y(?:es)?)\z/i)
end

opts = CooCoo::OptionParser.new do |o|
  o.on('--print-values BOOL') do |bool|
    options.print_values = truthy.call(bool)
  end

  o.on('--sixel', "toggles on the display of the dream as a Sixel graphic") do
    options.sixel = !options.sixel
  end

  o.on('--ascii', "toggles off the display of the dream as ASCII") do
    options.ascii = !options.ascii
  end

  o.on('--color BOOL', 'toggles the use of color in the ASCII dream') do |bool|
    $use_color = truthy.call(bool)
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
  end

  o.on('-l', '--loops NUMBER') do |n|
    options.loops = n.to_i
  end

  o.on('-r', '--rate NUMBER') do |n|
    options.rate = n.to_f
  end

  o.on('-v', '--verbose') do
    options.verbose = true
  end

  o.on('--status-delay SECONDS') do |n|
    options.status_delay = n.to_f
  end

  o.on('-i', '--initial NAME') do |n|
    # Only the first letter is significant: ones, rand, zeros or half.
    # Downcasing before indexing makes an empty value hit the ArgumentError
    # branch instead of raising NoMethodError on nil.
    options.initial_input = case n.downcase[0]
                            when 'o' then CooCoo::Vector.ones(28 * 28)
                            when 'r' then CooCoo::Vector.rand(28 * 28)
                            when 'z' then CooCoo::Vector.zeros(28 * 28)
                            when 'h' then CooCoo::Vector.new(28 * 28, 0.5)
                            else raise ArgumentError, "Unknown initial value #{n}"
                            end
  end
end
183
+
184
argv = opts.parse!(ARGV)
# Without this check a missing -m fails later with an opaque TypeError from File.extname(nil).
raise ArgumentError, "A model is required; pass one with -m/--model PATH." unless options.model_path

# A .bin extension means a Marshal dump; anything else goes through CooCoo::Network.load.
net = if File.extname(options.model_path) == '.bin'
        # Marshal data is binary, so read it without text-mode translation.
        # NOTE(review): Marshal.load must only be given trusted model files.
        Marshal.load(File.binread(options.model_path))
      else
        CooCoo::Network.load(options.model_path)
      end
190
+
191
# With no digits on the command line, dream every digit from 0 to 9.
argv = 10.times if argv.empty?

# Generate every dream first, then report on each one.
argv.collect do |arg|
  digit = arg.to_i
  $stdout.puts("Generating #{digit}") if options.verbose
  input = backprop_digit(options.loops, options.rate, net, digit, options.initial_input, options.verbose, options.status_delay, options.ascii, options.sixel)
  $stdout.flush
  [ digit, input ]
end.each do |digit, input|
  output, _hidden_state = net.predict(input, {})
  # backprop_digit trains toward output index digit % net.num_outputs. Check
  # that same index, so a digit >= num_outputs does not index past the end
  # of the output vector.
  target_index = digit % net.num_outputs
  passed = output[target_index] > 0.8
  color = passed ? :green : :red
  status_char = passed ? "\u2714" : "\u2718"

  puts("#{digit}".colorize(color))
  puts('=' * 8)
  puts
  puts(output_to_sixel(input)) if options.sixel
  puts(output_to_ascii(input)) if options.ascii
  puts(input) if options.print_values
  puts
  puts("#{status_char.colorize(color)} Output #{output[target_index]} #{output.magnitude} #{options.verbose ? output.inspect : ''}")
  puts
end