CooCoo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,42 @@
#!/bin/zsh
#
# Runs every *.spec file under spec/ through RSpec one at a time, writing an
# HTML report per spec under doc/ and an index page (doc/spec/index.html)
# that links to each report and marks it PASSED or FAILED.

INDEX=doc/spec/index.html

# The index is written before any spec runs, so its directory must exist
# first; on a fresh checkout doc/spec/ does not.
mkdir -p "${INDEX:h}"

# Start the index.html
cat <<EOT > "$INDEX"
<html>
<head>
<style>.PASSED { color: green; } .FAILED { color: red; }</style>
</head>
<body>
<ul>
EOT

# Run one spec file and append its result to the index.
#   $1 - path of the spec file, e.g. spec/foo/bar.spec
function run_spec()
{
  local spec="$1"
  local DIR="${spec:h}"            # same as `dirname $spec`
  local BASE="${${spec:t}%.spec}"  # same as `basename -s .spec $spec`
  local STATE                      # keep per-spec state out of the global scope
  print -rn -- "$spec "

  if mkdir -p "doc/$DIR" &&
     bundle exec rspec -Ilib -Iexamples "$spec" -f html > "doc/$DIR/$BASE.html"; then
    STATE=PASSED
  else
    STATE=FAILED
  fi

  print -r -- "$STATE"

  # Links are relative to doc/spec/, where the index lives.
  cat <<EOT >> "$INDEX"
<li class="$STATE"><a href="../$DIR/$BASE.html">$DIR/$BASE</a> &mdash; <span class="state">$STATE</span></li>
EOT
}

# And the loop: split find's output on newlines only, so paths containing
# spaces reach run_spec intact.
for spec in ${(f)"$(find spec -name '*.spec' | sort)"}; do
  run_spec "$spec"
done

# Finalize the index
cat <<EOT >> "$INDEX"
</ul>
</body>
</html>
EOT
@@ -0,0 +1,198 @@
#!/usr/bin/env ruby
#
# Command-line driver for CooCoo networks. It loads a training set from a
# generation script (--dataset) and loads (--model) or generates
# (--prototype) a network. It can then train the network with a named
# trainer (--trainer) and run a batch of test predictions (--predict).

# Put the checkout's lib/ and examples/ directories on the load path so that
# `require 'coo-coo'` resolves without installing the gem.
ROOT = File.dirname(File.dirname(__FILE__))
$LOAD_PATH << File.join(ROOT, 'lib') << File.join(ROOT, 'examples')

# Pathname is used for --model; load it explicitly rather than relying on a
# dependency to have done so.
require 'pathname'
require 'colorize'
require 'coo-coo'
8
+
# Instantiates the generation script at +path+ (its output goes to $stdout)
# and returns that script's option parser, so --help can list the flags the
# script itself accepts.
def load_script_opts(path)
  script = CooCoo::GenerationScript.new(path, $stdout)
  script.opts
end
14
+
# Debugging hook: sending SIGUSR1 to a running trainer opens a Pry console,
# bound to this script's top-level binding, in a background thread.
require 'pry'
$pry = {
  binding: binding,
  main_thread: Thread.current
}

Signal.trap("USR1") do
  # Only start a console when none is running. A console that has been
  # closed leaves a dead thread behind. With a plain `||=`, that dead thread
  # would make every later signal a silent no-op.
  unless $pry[:thread]&.alive?
    $pry[:thread] = Thread.new do
      $pry[:binding].pry
    end
  end
end
26
+
require 'ostruct'

# Default settings; the command-line parser below overwrites them. Flags
# without a default here (binary_blob, verbose) read as nil until set.
options = OpenStruct.new(
  epochs: 1000,
  dataset: nil,
  model_path: nil,
  prototype: nil,
  trainer: nil,
  num_tests: 20,
  start_tests_at: 0,
  test_cost: CooCoo::CostFunctions::MeanSquare
)
37
+
# Command-line flags.
#
# -h/--help prints these flags, followed by the flags of any dataset,
# prototype or trainer named earlier on the command line. Only flags that
# come before --help have been parsed by the time it runs.
opts = CooCoo::OptionParser.new do |o|
  o.on('-d', '--dataset PATH') do |path|
    options.dataset = path
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    # A ".bin" model is read and written with Marshal. --binary selects that
    # format for any extension.
    options.binary_blob = options.binary_blob || File.extname(options.model_path) == '.bin'
  end

  o.on('--prototype PATH') do |path|
    options.prototype = path
  end

  o.on('--binary') do
    options.binary_blob = true
  end

  o.on('-e', '--train NUMBER') do |n|
    options.epochs = n.to_i
  end

  o.on('-p', '--predict NUMBER') do |n|
    options.num_tests = n.to_i
  end

  o.on('-s', '--skip NUMBER') do |n|
    options.start_tests_at = n.to_i
  end

  o.on('-t', '--trainer NAME') do |name|
    options.trainer = name
  end

  o.on('--test-cost NAME') do |name|
    options.test_cost = CooCoo::CostFunctions.from_name(name)
  end

  o.on('--verbose') do
    options.verbose = true
  end

  o.on('-h', '--help') do
    # Print the yielded parser and keep helper results in their own locals.
    # The outer +opts+ (this parser) is never reassigned.
    puts(o)
    if options.dataset
      puts(load_script_opts(options.dataset.to_s))
    end
    if options.prototype
      puts(load_script_opts(options.prototype.to_s))
    end
    if options.trainer
      trainer_class = CooCoo::Trainer.from_name(options.trainer)
      raise ArgumentError.new("Unknown trainer #{options.trainer}") unless trainer_class
      trainer_opts, _ = trainer_class.options
      puts(trainer_opts)
    end
    exit
  end
end
99
+
argv = opts.parse!(ARGV)

# Without a dataset script there is nothing to train or test on. Fail with a
# clear message rather than handing GenerationScript an empty path.
raise ArgumentError.new("--dataset PATH is required") unless options.dataset

puts("Loading training set #{options.dataset}")

# The script is given the remaining arguments and hands back what it left
# unconsumed, along with the training set.
training_set_gen = CooCoo::GenerationScript.new(options.dataset.to_s, $stdout)
argv, training_set = training_set_gen.call(argv)

# Resume from an existing model file if there is one; otherwise build a fresh
# network from the prototype script. (File.exist?: File.exists? was removed
# in Ruby 3.2.)
if options.model_path && File.exist?(options.model_path.to_s)
  puts("Loading network #{options.model_path}")
  if options.binary_blob
    # Marshal data is binary, so read it without newline/encoding translation.
    # NOTE: Marshal.load can instantiate arbitrary objects, so only load model
    # files you trust.
    net = Marshal.load(File.binread(options.model_path))
  else
    net = CooCoo::Network.load(options.model_path)
  end
else
  unless options.prototype
    raise ArgumentError.new("--prototype PATH is required when no existing --model is given")
  end
  puts("Generating network from #{options.prototype}")
  net_gen = CooCoo::GenerationScript.new(options.prototype.to_s, $stdout)
  argv, net = net_gen.call(argv, training_set)
end
119
+
# Optional trainer. Its own flags are parsed from whatever the dataset and
# prototype scripts left over.
trainer = nil
trainer_options = nil

if options.trainer
  trainer = CooCoo::Trainer.from_name(options.trainer)
  # Same check --help performs: report the bad name instead of crashing with
  # NoMethodError on nil.
  raise ArgumentError.new("Unknown trainer #{options.trainer}") unless trainer
  t_opts, trainer_options = trainer.options
  argv = t_opts.parse!(argv)
end

# Every argument must have been consumed by some parser by now.
unless argv.empty?
  raise ArgumentError.new("Unknown arguments: #{argv.inspect}")
end
132
+
# Print a summary of the network, then one row per layer:
# index, inputs, size, class.
puts("Net ready:")
{
  "Age" => :age,
  "Activation" => :activation_function,
  "Inputs" => :num_inputs,
  "Outputs" => :num_outputs,
  "Layers" => :num_layers
}.each do |label, attribute|
  puts("\t#{label}: #{net.public_send(attribute)}")
end
net.layers.each_with_index do |layer, index|
  puts("\t\t#{index}\t#{layer.num_inputs}\t#{layer.size}\t#{layer.class}")
end

$stdout.flush
144
+
if options.trainer
  # Round up. If the example count is not a multiple of the batch size, a
  # trailing partial batch presumably still produces a callback (TODO
  # confirm against the trainer). Truncating would then size the progress
  # bar one short of the increments it receives. char-rnn sizes its bar the
  # same way.
  num_batches = (options.epochs.to_i * training_set.size / trainer_options.batch_size.to_f).ceil

  puts("Training in #{num_batches} batches of #{trainer_options.batch_size} examples for #{options.epochs} epochs at a rate of #{trainer_options.learning_rate} with #{trainer.name} using #{trainer_options.cost_function.name}.")

  bar = CooCoo::ProgressBar.create(:total => num_batches)

  trainer.train({ network: net,
                  data: training_set.each.cycle(options.epochs)
                }.merge(trainer_options.to_h)) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?
    status = []
    status << "Cost\t#{avg_err.average}\t#{options.verbose ? avg_err : nil}"
    status << "Batch #{stats.batch}/#{num_batches} took #{stats.total_time} seconds"
    # Checkpoint after every batch callback.
    if options.model_path
      if options.binary_blob
        File.write_to(options.model_path) do |f|
          f.write(Marshal.dump(net))
        end
      else
        net.save(options.model_path)
      end
      status << "Saved network to #{options.model_path}"
    end
    bar.log(status.join("\n"))
    bar.increment
  end
end
174
+
CHECKMARK = "\u2714".green
CROSSMARK = "\u2718".red

# Evaluate the network on up to num_tests examples, starting at index
# start_tests_at. An example passes when the arg-max of the output matches
# the arg-max of the target. `--predict 0` skips testing entirely instead of
# reporting 0/0 = NaN%.
if options.num_tests > 0
  puts("Running #{options.num_tests} tests starting from #{options.start_tests_at}:")
  passes = 0
  # Skip ahead lazily so only the requested examples are materialized.
  samples = training_set.each.lazy.drop(options.start_tests_at).first(options.num_tests)
  samples.each_with_index do |(target, pixels), i|
    out, _hidden_state = net.predict(pixels)
    loss = options.test_cost.call(target, out)
    target_max = target.each.with_index.max
    out_max = out.each.with_index.max
    passed = target_max[1] == out_max[1]
    passes += 1 if passed
    puts("#{passed ? CHECKMARK : CROSSMARK} #{options.start_tests_at + i}\t#{loss.average}\t#{target_max}\t#{out[target_max[1]]}\t#{out_max}\t#{target}\t#{out}")
    $stdout.flush
    GC.start
  end

  puts
  # Score against the examples actually run. The dataset may hold fewer than
  # num_tests after skipping.
  ran = samples.size
  percent = ran.zero? ? 0.0 : passes / ran.to_f * 100.0
  puts("#{passes}/#{ran} #{percent}% passed")
end
@@ -0,0 +1,13 @@
#!/bin/zsh
#
# Copy stdin to stdout unchanged while graphing the second field of every
# line containing "cost" (case-insensitive) in a live `trend` window.
# Environment tunables: TREND_TITLE, TREND_GEOMETRY, TREND_SIZE, TREND_SCALE.

# Run a command with line-buffered stdout, so each value reaches `trend` as
# soon as it is printed instead of when a pipe buffer fills.
# "$@" forwards the arguments exactly as given, including empty ones.
function run()
{
  stdbuf -oL "$@"
}

# Spaces become tabs so `cut -f 2` picks the second space- or tab-separated
# field.
tee >(run grep -i cost |
      run sed -e 's: :\t:g' |
      run cut -f 2 |
      trend -t "${TREND_TITLE:-Cost}" -c 1a -geometry "${TREND_GEOMETRY:-320x64-0-24}" - "${TREND_SIZE:-$((60*4*5))x2}" ${TREND_SCALE})

# Kill any background jobs still running once tee finishes.
[[ -z `jobs -p` ]] || kill $(jobs -p)
@@ -0,0 +1,405 @@
1
+ require 'coo-coo'
2
+
# Abstract base for the character encoders. It maps bytes to input vectors
# and network output vectors back to bytes. Subclasses provide vector_size,
# encode_input, decode_byte and decode_output; the string helpers below are
# built on top of those.
class InputEncoder
  protected

  def initialize
  end

  public

  # Length of the vectors produced by encode_input.
  def vector_size
    raise NotImplementedError
  end

  # Encodes a single byte value as an input vector.
  def encode_input(s)
    raise NotImplementedError
  end

  # Maps an output index back to a byte value.
  def decode_byte(b)
    raise NotImplementedError
  end

  # Maps an output vector back to a byte value.
  def decode_output(b)
    raise NotImplementedError
  end

  # Encodes every byte of string +s+, returning one vector per byte.
  def encode_string(s)
    s.each_byte.map { |byte| encode_input(byte) }
  end

  # Packs an array of byte values back into a binary string.
  def decode_sequence(s)
    s.pack('c*')
  end

  # Decodes each vector in +output+ to a byte, then packs them into a string.
  def decode_to_string(output)
    decoded = output.collect { |vec| decode_output(vec) }
    decode_sequence(decoded)
  end
end
37
+
# Compact 39-slot one-hot encoder covering letters, digits and space.
# Slot layout:
#   0       any other byte
#   1       space
#   2..27   letters (upper and lower case share a slot; decoding yields
#           lower case)
#   28..37  digits
#   38      never produced by encode_byte; decodes to space
class LittleInputEncoder < InputEncoder
  UA = 'A'.bytes[0]
  UZ = 'Z'.bytes[0]
  LA = 'a'.bytes[0]
  LZ = 'z'.bytes[0]
  N0 = '0'.bytes[0]
  N9 = '9'.bytes[0]
  SPACE = ' '.bytes[0]

  def vector_size
    39
  end

  # Byte value -> slot index, per the layout above.
  def encode_byte(b)
    return b - UA + 2 if b.between?(UA, UZ)
    return b - LA + 2 if b.between?(LA, LZ)
    return b - N0 + 28 if b.between?(N0, N9)

    b == SPACE ? 1 : 0
  end

  # Slot index -> byte value. Slots 0, 1 and anything past the digits
  # become a space.
  def decode_byte(i)
    return SPACE if i <= 1
    return LA + i - 2 if i <= 27
    return N0 + i - 28 if i <= 37

    SPACE
  end

  # One-hot vector with a 1.0 in the slot for byte +b+.
  def encode_input(b)
    CooCoo::Vector.zeros(vector_size).tap { |vec| vec[encode_byte(b)] = 1.0 }
  end

  # Decodes the slot holding the largest output value.
  def decode_output(v)
    _largest, slot = v.each_with_index.max
    decode_byte(slot)
  end
end
88
+
# One-hot encoder over all 256 byte values: byte b sets element b.
class AsciiInputEncoder < InputEncoder
  def vector_size
    256
  end

  # Returns the one-hot vector for byte +b+. There are only 256 distinct
  # vectors, so each is built once per encoder and then reused.
  # NOTE: the returned vector is shared between calls, so callers must not
  # mutate it.
  def encode_input(b)
    @encoded_inputs ||= Hash.new do |cache, byte|
      vec = CooCoo::Vector.zeros(vector_size)
      vec[byte] = 1.0
      cache[byte] = vec
    end

    @encoded_inputs[b]
  end

  # Returns the byte (index) whose output value is largest.
  # Deliberately not memoized. Outputs are fresh float vectors that almost
  # never repeat, so a cache keyed on them would never hit. It would also
  # keep every output vector alive for the life of the process.
  def decode_output(v)
    v.each_with_index.max[1]
  end
end
113
+
# Builds next-byte-prediction training pairs from +data+.
#
# For each position i, the input window is the +sequence_size+ bytes
# starting at i. The target is the same window shifted one byte ahead.
# Windows that run past the end of +data+ are padded with byte 0, so input
# and target always hold exactly +sequence_size+ elements.
#
# @param data [Array<Integer>] byte values
# @param sequence_size [Integer] window length
# @param encoder [#encode_input] converts a byte into an input vector
# @return [Enumerator] yielding [CooCoo::Sequence target, CooCoo::Sequence input]
def training_enumerator(data, sequence_size, encoder)
  # Index past the end instead of slicing: data[i] is nil there, so the
  # `|| 0` padding applies. data[i, n] would silently return a shorter array.
  encode_window = lambda do |start|
    Array.new(sequence_size) { |offset| encoder.encode_input(data[start + offset] || 0) }
  end

  Enumerator.new do |yielder|
    data.size.times do |i|
      input = encode_window.call(i)
      output = encode_window.call(i + 1)
      yielder << [ CooCoo::Sequence[output], CooCoo::Sequence[input] ]
    end
  end
end
130
+
if __FILE__ == $0
  # Generator sampler chosen by --generator-sample-top. Returns the index of
  # a random element among the +range+ largest values of +arr+. The generator
  # passes its temperature as +range+.
  def sample_top(arr, range)
    c = arr.each.with_index.sort[-range, range.abs].collect(&:last)
    c[rand(c.size)]
  end

  # Default generator sampler. Draws a uniform random vector and keeps the
  # indices whose normalized output is at least their draw. Those are sorted
  # by margin, largest first, and one is taken at a position scaled by
  # +temperature+. Returns 0 when nothing is picked.
  # NOTE(review): with temperature > 1 (the default is 4), the scaled
  # position often lands past the end of +picks+, so this returns 0 (a NUL
  # byte) much of the time. The intended meaning of temperature here is
  # unclear; confirm before relying on it.
  def sample(arr, temperature = 1.0)
    narr = arr.normalize
    picks = (CooCoo::Vector.rand(arr.size) - narr).each.with_index.select { |v, i| v <= 0.0 }.sort
    pick = picks[rand(picks.size) * temperature]
    (pick && pick[1]) || 0
  end

  require 'ostruct'

  options = OpenStruct.new
  options.encoder = AsciiInputEncoder.new
  options.recurrent_size = 1024
  options.activation_function = CooCoo::ActivationFunctions.from_name('Logistic')
  options.epochs = 1000
  options.model_path = "char-rnn.coo-coo_model"
  options.input_path = nil
  options.backprop_limit = nil
  options.trainer = nil
  options.sequence_size = 4
  options.num_layers = 1
  options.hidden_size = nil
  options.num_recurrent_layers = 2
  options.softmax = nil
  options.verbose = false
  options.generator = false
  options.generator_temperature = 4
  options.generator_amount = 140
  options.generator_init = "\n"
  options.sampler = method(:sample)

  opts = CooCoo::OptionParser.new do |o|
    o.on('-v', '--verbose') do
      options.verbose = true
    end

    o.on('--little') do
      options.encoder = LittleInputEncoder.new
    end

    o.on('-m', '--model PATH') do |path|
      options.model_path = path
    end

    o.on('-b', '--binary') do
      options.binary = true
    end

    o.on('--recurrent-size NUMBER') do |size|
      options.recurrent_size = size.to_i
    end

    o.on('--activation NAME') do |name|
      options.activation_function = CooCoo::ActivationFunctions.from_name(name)
    end

    o.on('--epochs NUMBER') do |n|
      options.epochs = n.to_i
    end

    o.on('--backprop-limit NUMBER') do |n|
      options.backprop_limit = n.to_i
    end

    o.on('--hidden-size NUMBER') do |n|
      options.hidden_size = n.to_i
    end

    o.on('--recurrent-layers NUMBER') do |n|
      options.num_recurrent_layers = n.to_i
    end

    o.on('--softmax') do
      options.softmax = true
      # NOTE(review): cost_function is recorded here but never read in this
      # script. The loss actually used comes from the trainer's own options.
      options.cost_function = CooCoo::CostFunctions::CrossEntropy
    end

    o.on('-p', '--predict') do
      options.trainer = nil
    end

    o.on('-t', '--trainer NAME') do |name|
      options.trainer = CooCoo::Trainer.from_name(name)
    end

    o.on('-n', '--sequence-size NUMBER') do |n|
      n = n.to_i
      raise ArgumentError.new("sequence-size must be > 0") if n <= 0
      options.sequence_size = n
    end

    o.on('--layers NUMBER') do |n|
      n = n.to_i
      raise ArgumentError.new("number of layers must be > 0") if n <= 0
      options.num_layers = n
    end

    o.on('-g', '--generate AMOUNT') do |v|
      options.generator = true
      options.generator_amount = v.to_i
    end

    o.on('--generator-init STRING') do |v|
      options.generator_init = v
    end

    o.on('--generator-temp NUMBER') do |v|
      options.generator_temperature = v.to_i
    end

    o.on('--generator-sample-top') do
      options.sampler = method(:sample_top)
    end

    o.on('--seed NUMBER') do |v|
      srand(v.to_i)
    end

    o.on('-h', '--help') do
      puts(o)
      if options.trainer
        t_opts, _ = options.trainer.options
        puts(t_opts)
      end

      exit
    end
  end

  argv = opts.parse!(ARGV)

  # Trainer-specific flags are parsed from what the main parser left over.
  if options.trainer
    t_opts, trainer_options = options.trainer.options
    argv = t_opts.parse!(argv)
  end

  options.input_path = argv[0]
  encoder = options.encoder
  options.hidden_size ||= encoder.vector_size

  # File.exist?: File.exists? was deprecated and removed in Ruby 3.2.
  if File.exist?(options.model_path)
    $stdout.print("Loading #{options.model_path}...")
    $stdout.flush
    net = if options.binary
            # Marshal data is binary, so read it without newline/encoding
            # translation.
            # NOTE: Marshal.load can instantiate arbitrary objects, so only
            # load model files you trust.
            Marshal.load(File.binread(options.model_path))
          else
            CooCoo::TemporalNetwork.new(network: CooCoo::Network.load(options.model_path))
          end
    puts("\rLoaded #{options.model_path}:")
  else
    puts("Creating new network")
    puts("\tNumber of inputs: #{encoder.vector_size}")
    puts("\tNumber of layers: #{options.num_layers}")
    puts("\tHidden size: #{options.hidden_size}#{' with mix' if options.hidden_size != encoder.vector_size}")
    puts("\tRecurrent size: #{options.recurrent_size}")
    puts("\tActivation: #{options.activation_function}")
    puts("\tRecurrent layers: #{options.num_recurrent_layers}")

    net = CooCoo::TemporalNetwork.new()
    # Project the encoder's vectors into the hidden size when they differ.
    if options.hidden_size != encoder.vector_size
      net.layer(CooCoo::Layer.new(encoder.vector_size, options.hidden_size, options.activation_function))
    end

    # Each recurrent block is: frontend, num_layers layers over the hidden
    # size plus the recurrent width, then the matching backend.
    options.num_recurrent_layers.to_i.times do |n|
      rec = CooCoo::Recurrence::Frontend.new(options.hidden_size, options.recurrent_size)
      net.layer(rec)
      options.num_layers.times do
        net.layer(CooCoo::Layer.new(options.hidden_size + rec.recurrent_size, options.hidden_size + rec.recurrent_size, options.activation_function))
      end

      net.layer(rec.backend)
    end

    # Project back to the encoder's vector size for the output.
    if options.hidden_size != encoder.vector_size
      net.layer(CooCoo::Layer.new(options.hidden_size, encoder.vector_size, options.activation_function))
    end

    if options.softmax
      net.layer(CooCoo::LinearLayer.new(encoder.vector_size, CooCoo::ActivationFunctions.from_name('ShiftedSoftMax')))
    end
  end

  net.backprop_limit = options.backprop_limit

  puts("\tAge: #{net.age}")
  puts("\tInputs: #{net.num_inputs}")
  puts("\tOutputs: #{net.num_outputs}")
  puts("\tLayers: #{net.num_layers}")

  data = if options.input_path
           puts("Reading #{options.input_path}")
           File.read(options.input_path)
         else
           puts("Reading stdin...")
           $stdin.read
         end
  data = data.bytes
  puts("Read #{data.size} bytes")

  puts("Splitting into #{options.sequence_size} byte sequences.")
  training_data = training_enumerator(data, options.sequence_size, encoder)

  if options.trainer
    puts("Training on #{data.size} bytes from #{options.input_path || "stdin"} in #{options.epochs} epochs in batches of #{trainer_options.batch_size} at a learning rate of #{trainer_options.learning_rate}...")

    trainer = options.trainer
    bar = CooCoo::ProgressBar.create(:total => (options.epochs * data.size / trainer_options.batch_size.to_f).ceil)
    trainer.train({ network: net,
                    data: training_data.cycle(options.epochs),
                  }.merge(trainer_options.to_h)) do |stats|
      average_loss = stats.average_loss
      cost = average_loss.average
      raise 'Cost went to NAN' if cost.nan?
      # +cost+ is already the scalar average. --verbose also shows the
      # per-output losses it was computed from.
      status = [ "Cost #{cost} #{options.verbose ? average_loss : ''}" ]

      # Checkpoint after every batch callback.
      File.write_to(options.model_path) do |f|
        if options.binary
          # Raw Marshal bytes, with no trailing newline appended.
          f.write(Marshal.dump(net))
        else
          f.puts(net.to_hash.to_yaml)
        end
      end
      status << "Saved to #{options.model_path}"
      bar.log(status.join("\n"))
      bar.increment
    end
  end

  if options.generator
    # Prime the hidden state with the init string and then the whole input.
    # Then emit generator_amount sampled bytes, feeding each one back in.
    o, hidden_state = net.predict(encoder.encode_string(options.generator_init), {})
    data.each do |b|
      o, hidden_state = net.predict(encoder.encode_input(b), hidden_state)
    end
    options.generator_amount.to_i.times do |n|
      c = options.sampler.call(o, options.generator_temperature)
      # Samplers return a slot index; only the 256-slot encoder has
      # slot == byte.
      c = encoder.decode_byte(c) if encoder.vector_size != 256
      $stdout.write(c.chr)
      $stdout.flush
      o, hidden_state = net.predict(encoder.encode_input(c), hidden_state)
    end

    $stdout.puts
  else
    # Show the prediction for every input window...
    puts("Predicting:")
    hidden_state = nil
    s = data.size.times.collect do |i|
      input = data[i, options.sequence_size].collect { |e| encoder.encode_input(e || 0) }
      output, hidden_state = net.predict(input, hidden_state)
      output.collect { |b| encoder.decode_output(b) }
    end

    s.each_with_index do |c, i|
      input = data[i, options.sequence_size]
      puts("#{i} #{input.inspect} -> #{c.inspect}\t#{encoder.decode_sequence(input)} -> #{encoder.decode_sequence(c)}")
    end

    puts(encoder.decode_sequence(s.collect(&:first)))
    puts(encoder.decode_sequence(s.collect(&:last)))

    # ...then free-run from a random input byte for data.size steps, feeding
    # each decoded output back in as the next input.
    hidden_state = nil
    c = data[rand(data.size)]
    s = data.size.times.collect do |i|
      o, hidden_state = net.predict(encoder.encode_input(c), hidden_state)
      c = encoder.decode_output(o)
    end

    puts(s.inspect)
    puts(encoder.decode_sequence(s))
  end
end