CooCoo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (105) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,42 @@
1
#!/bin/zsh
#
# Runs every *.spec file found under spec/ through RSpec, writing one HTML
# report per spec below doc/ and an index page that links to each report
# together with its PASSED/FAILED state.

INDEX=doc/spec/index.html

# The index is written before any spec runs, so its directory has to exist
# up front (previously only run_spec created directories).
mkdir -p "$(dirname "$INDEX")"

# Start the index.html
cat <<EOT > "$INDEX"
<html>
<head>
<style>.PASSED { color: green; } .FAILED { color: red; }</style>
</head>
<body>
<ul>
EOT

# The spec runner
#
# $1 - path of the spec file to run.
# Writes the HTML report to doc/<dir of spec>/<basename>.html, echoes the
# outcome and appends a list item for the spec to the index.
function run_spec()
{
  local spec="$1"
  local DIR=`dirname "$spec"`
  local BASE=`basename -s .spec "$spec"`
  # Keep the outcome local so it does not leak into the calling shell.
  local STATE
  echo -n "$spec "

  mkdir -p "doc/$DIR" && bundle exec rspec -Ilib -Iexamples "$spec" -f html > "doc/$DIR/$BASE.html" && STATE=PASSED || STATE=FAILED

  echo $STATE

  cat <<EOT >> "$INDEX"
<li class="$STATE"><a href="../$DIR/$BASE.html">$DIR/$BASE</a> &mdash; <span class="state">$STATE</span></li>
EOT
}

# And the loop
# (f) splits the command output on newlines only, so paths containing
# spaces stay intact.
for spec in ${(f)"$(find spec -name \*.spec | sort)"}; do
  run_spec "$spec"
done

# Finalize the index
cat <<EOT >> "$INDEX"
</ul>
</body>
</html>
EOT
@@ -0,0 +1,198 @@
1
#!/usr/bin/env ruby
# Command line front end that loads a data set, loads or generates a network,
# optionally trains it, and then runs predictions against the data set.
# This first part sets up the load path, loads the library and installs a
# USR1 signal handler that opens a Pry session.

ROOT = File.dirname(File.dirname(__FILE__))
$: << File.join(ROOT, 'lib') << File.join(ROOT, 'examples')

require 'colorize'
require 'coo-coo'

# Builds a CooCoo::GenerationScript for the file at +path+ (logging to
# $stdout) and returns its +opts+; the --help handler prints the result.
def load_script_opts(path)
  CooCoo::GenerationScript.new(path, $stdout).opts
end

require 'pry'
# State shared with the signal handler: the top level binding to inspect and
# the thread that was current when the script started.
$pry = {
  binding: binding,
  main_thread: Thread.current
}

# Sending USR1 to the process starts Pry on the saved binding in a separate
# thread; ||= keeps repeated signals from starting further sessions.
Signal.trap("USR1") do
  $pry[:thread] ||= Thread.new do
    $pry[:binding].pry
  end
end
26
+
27
require 'ostruct'
require 'pathname' # Pathname is used by the --model handler below.

# Default settings; the option handlers below overwrite them.
options = OpenStruct.new
options.epochs = 1000
options.dataset = nil
options.model_path = nil
options.prototype = nil
options.trainer = nil
options.num_tests = 20
options.start_tests_at = 0
options.test_cost = CooCoo::CostFunctions::MeanSquare

opts = CooCoo::OptionParser.new do |o|
  o.on('-d', '--dataset PATH') do |path|
    options.dataset = path
  end

  o.on('-m', '--model PATH') do |path|
    options.model_path = Pathname.new(path)
    # A ".bin" extension selects the Marshal format unless --binary already did.
    options.binary_blob = options.binary_blob || File.extname(options.model_path) == '.bin'
  end

  o.on('--prototype PATH') do |path|
    options.prototype = path
  end

  o.on('--binary') do
    options.binary_blob = true
  end

  o.on('-e', '--train NUMBER') do |n|
    options.epochs = n.to_i
  end

  o.on('-p', '--predict NUMBER') do |n|
    options.num_tests = n.to_i
  end

  o.on('-s', '--skip NUMBER') do |n|
    options.start_tests_at = n.to_i
  end

  o.on('-t', '--trainer NAME') do |name|
    options.trainer = name
  end

  o.on('--test-cost NAME') do |name|
    options.test_cost = CooCoo::CostFunctions.from_name(name)
  end

  o.on('--verbose') do
    options.verbose = true
  end

  # Prints this parser's help followed by the help of whichever dataset
  # script, prototype script and trainer were named before -h on the command
  # line, then exits. Block-local names are used so the outer +opts+ parser
  # variable is no longer overwritten.
  o.on('-h', '--help') do
    puts(opts)
    if options.dataset
      puts(load_script_opts(options.dataset.to_s))
    end
    if options.prototype
      puts(load_script_opts(options.prototype.to_s))
    end
    if options.trainer
      t = CooCoo::Trainer.from_name(options.trainer)
      raise ArgumentError, "Unknown trainer #{options.trainer}" unless t
      trainer_opts, _ = t.options
      puts(trainer_opts)
    end
    exit
  end
end
99
+
100
# Each stage below consumes the arguments it understands and hands back the
# remainder; anything still left at the end is reported as an error.
argv = opts.parse!(ARGV)

puts("Loading training set #{options.dataset}")

training_set_gen = CooCoo::GenerationScript.new(options.dataset.to_s, $stdout)
argv, training_set = training_set_gen.call(argv)

# File.exist? is used because File.exists? was removed in Ruby 3.2.
if options.model_path && File.exist?(options.model_path.to_s)
  puts("Loading network #{options.model_path}")
  if options.binary_blob
    # Read in binary mode so the Marshal data is not subject to newline or
    # encoding conversion.
    # NOTE: Marshal.load can instantiate arbitrary objects; only load model
    # files from trusted sources.
    net = Marshal.load(File.binread(options.model_path))
  else
    net = CooCoo::Network.load(options.model_path)
  end
else
  puts("Generating network from #{options.prototype}")
  net_gen = CooCoo::GenerationScript.new(options.prototype.to_s, $stdout)
  argv, net = net_gen.call(argv, training_set)
end

trainer = nil
trainer_options = nil

if options.trainer
  trainer = CooCoo::Trainer.from_name(options.trainer)
  # Same check as the --help handler: fail with a clear message instead of a
  # NoMethodError on nil.
  raise ArgumentError, "Unknown trainer #{options.trainer}" unless trainer
  t_opts, trainer_options = trainer.options
  argv = t_opts.parse!(argv)
end

unless argv.empty?
  raise ArgumentError, "Unknown arguments: #{argv.inspect}"
end

# Summarize the network, one line per layer.
puts("Net ready:")
puts("\tAge: #{net.age}")
puts("\tActivation: #{net.activation_function}")
puts("\tInputs: #{net.num_inputs}")
puts("\tOutputs: #{net.num_outputs}")
puts("\tLayers: #{net.num_layers}")
net.layers.each_with_index do |l, i|
  puts("\t\t#{i}\t#{l.num_inputs}\t#{l.size}\t#{l.class}")
end

$stdout.flush
144
+
145
# Training only runs when a trainer was named on the command line. After
# every batch the cost is logged and, if a model path was given, the network
# is saved.
if options.trainer
  num_batches = options.epochs.to_i * training_set.size / trainer_options.batch_size

  puts("Training in #{num_batches} batches of #{trainer_options.batch_size} examples for #{options.epochs} epochs at a rate of #{trainer_options.learning_rate} with #{trainer.name} using #{trainer_options.cost_function.name}.")

  bar = CooCoo::ProgressBar.create(:total => num_batches)

  # The trainer's own parsed options are merged over the network and data.
  train_args = { network: net, data: training_set.each.cycle(options.epochs) }.merge(trainer_options.to_h)

  trainer.train(train_args) do |stats|
    avg_err = stats.average_loss
    raise "Cost went to NAN" if avg_err.nan?

    status = [
      "Cost\t#{avg_err.average}\t#{options.verbose ? avg_err : nil}",
      "Batch #{stats.batch}/#{num_batches} took #{stats.total_time} seconds"
    ]

    if options.model_path
      if options.binary_blob
        File.write_to(options.model_path) { |f| f.write(Marshal.dump(net)) }
      else
        net.save(options.model_path)
      end
      status << "Saved network to #{options.model_path}"
    end

    bar.log(status.join("\n"))
    bar.increment
  end
end
174
+
175
CHECKMARK = "\u2714".green
CROSSMARK = "\u2718".red

# Runs the network against up to +num_tests+ examples of the data set,
# starting at +start_tests_at+. An example passes when the index of the
# largest output equals the index of the largest target value.
if options.num_tests
  puts("Running #{options.num_tests} tests starting from #{options.start_tests_at}:")
  e = training_set.each
  e = e.drop(options.start_tests_at) if options.start_tests_at > 0
  # The data set may hold fewer examples than requested, so the summary is
  # based on the number actually run.
  examples = e.first(options.num_tests)
  passes = 0
  examples.each_with_index do |(target, pixels), i|
    out, _hidden = net.predict(pixels)
    loss = options.test_cost.call(target, out)
    target_max = target.each.with_index.max
    out_max = out.each.with_index.max
    passed = target_max[1] == out_max[1]
    passes += 1 if passed
    puts("#{passed ? CHECKMARK : CROSSMARK} #{options.start_tests_at + i}\t#{loss.average}\t#{target_max}\t#{out[target_max[1]]}\t#{out_max}\t#{target}\t#{out}")
    $stdout.flush
    GC.start
  end

  puts
  # Avoid printing NaN when no tests ran.
  percent = examples.empty? ? 0.0 : passes / examples.size.to_f * 100.0
  puts("#{passes}/#{examples.size} #{percent}% passed")
end
@@ -0,0 +1,13 @@
1
#!/bin/zsh
#
# Copies stdin to stdout and, on the side, filters the lines containing
# "cost" (case insensitive), turns spaces into tabs, takes the second field
# and pipes it into `trend`.
# The trend invocation reads TREND_TITLE, TREND_GEOMETRY, TREND_SIZE and
# TREND_SCALE from the environment, with defaults for the first three.

# Runs the given command with line buffered stdout (stdbuf -oL).
function run()
{
  # "$@" forwards every argument unchanged, including empty ones.
  stdbuf -oL "$@"
}

tee >(run grep -i cost |
      run sed -e 's: :\t:g' |
      run cut -f 2 |
      trend -t "${TREND_TITLE:-Cost}" -c 1a -geometry ${TREND_GEOMETRY:-320x64-0-24} - ${TREND_SIZE:-$((60*4*5))x2} ${TREND_SCALE})

# Kill any jobs that are still running once tee has finished.
[[ -z `jobs -p` ]] || kill $(jobs -p)
@@ -0,0 +1,405 @@
1
+ require 'coo-coo'
2
+
3
# Abstract base class for encoders that translate between bytes and network
# vectors. Subclasses provide #vector_size, #encode_input, #decode_byte and
# #decode_output; the string and sequence helpers are built on those.
#
# The former `protected`/`public` toggling around an empty #initialize was
# dropped: Ruby always makes #initialize private, so it had no effect.
class InputEncoder
  # Size of the vectors this encoder works with.
  def vector_size
    raise NotImplementedError
  end

  # Encodes a single input value; #encode_string calls this once per byte.
  def encode_input(s)
    raise NotImplementedError
  end

  # Decodes a single encoded value back into a byte.
  def decode_byte(b)
    raise NotImplementedError
  end

  # Decodes one output value; #decode_to_string calls this once per element.
  def decode_output(b)
    raise NotImplementedError
  end

  # Returns an Array holding #encode_input of every byte of the string +s+.
  def encode_string(s)
    s.bytes.collect { |b| encode_input(b) }
  end

  # Packs the Array of byte values +s+ into a String. 'C' is the unsigned
  # 8-bit directive, which matches byte values in 0..255.
  def decode_sequence(s)
    s.pack('C*')
  end

  # Decodes every element of +output+ with #decode_output and packs the
  # resulting bytes into a String.
  def decode_to_string(output)
    decode_sequence(output.collect { |v| decode_output(v) })
  end
end
37
+
38
# Encoder with a reduced alphabet. Index 0 stands for any byte that is not
# otherwise mapped, 1 for a space, 2..27 for letters (upper and lower case
# share an index) and 28..37 for the digits.
class LittleInputEncoder < InputEncoder
  UA = 'A'.ord
  UZ = 'Z'.ord
  LA = 'a'.ord
  LZ = 'z'.ord
  N0 = '0'.ord
  N9 = '9'.ord
  SPACE = ' '.ord

  # Length of the one-hot vectors built by #encode_input.
  def vector_size
    39
  end

  # Maps the byte +b+ to its index in the reduced alphabet.
  def encode_byte(b)
    return (b - UA) + 2 if b.between?(UA, UZ)
    return (b - LA) + 2 if b.between?(LA, LZ)
    return (b - N0) + 26 + 2 if b.between?(N0, N9)

    b == SPACE ? 1 : 0
  end

  # Maps the index +i+ back to a byte. Letters come back in lower case;
  # indexes up to 1 and above 37 both decode to a space.
  def decode_byte(i)
    return SPACE if i <= 1
    return LA + (i - 2) if i <= 27
    return N0 + (i - 28) if i <= 37

    SPACE
  end

  # Returns a zero vector of #vector_size with the slot for +b+ set to 1.0.
  def encode_input(b)
    CooCoo::Vector.zeros(vector_size).tap do |one_hot|
      one_hot[encode_byte(b)] = 1.0
    end
  end

  # Decodes the index of the largest element of +v+ into a byte.
  def decode_output(v)
    _largest, index = v.each_with_index.max
    decode_byte(index)
  end
end
88
+
89
# Encoder that gives each of the 256 byte values its own slot in a one-hot
# vector.
class AsciiInputEncoder < InputEncoder
  # Length of the one-hot vectors built by #encode_input.
  def vector_size
    256
  end

  # Returns the one-hot vector for the byte +b+. Vectors are built once per
  # byte value and then reused; the cache lives on the instance rather than
  # in a global variable.
  def encode_input(b)
    @encoded_inputs ||= {}
    @encoded_inputs[b] ||= begin
      v = CooCoo::Vector.zeros(vector_size)
      v[b] = 1.0
      v
    end
  end

  # Returns the index of the largest element of +v+. The result is computed
  # directly: the previous global cache was keyed by the output vectors
  # themselves and kept an entry for every distinct vector it was given.
  def decode_output(v)
    v.each_with_index.max[1]
  end
end
113
+
114
# Returns an Enumerator that yields one [target, input] pair of
# CooCoo::Sequence objects per position in +data+. The input is the window
# of +sequence_size+ values starting at that position and the target is the
# window starting one position later, each value run through
# +encoder.encode_input+.
#
# Windows that reach past the end of +data+ are padded with 0 so input and
# target always hold +sequence_size+ elements. The original `e || 0` shows
# padding was intended, but Array#[start, length] truncates instead of
# returning nils, so the trailing windows came out short and the target was
# shorter than the input.
def training_enumerator(data, sequence_size, encoder)
  window = lambda do |start|
    Array.new(sequence_size) { |offset| encoder.encode_input(data[start + offset] || 0) }
  end

  Enumerator.new do |yielder|
    data.size.times do |i|
      input = window.call(i)
      output = window.call(i + 1)
      yielder << [ CooCoo::Sequence[output], CooCoo::Sequence[input] ]
    end
  end
end
130
+
131
+ if __FILE__ == $0
132
+ def sample_top(arr, range)
133
+ c = arr.each.with_index.sort[-range, range.abs].collect(&:last)
134
+ c[rand(c.size)]
135
+ end
136
+
137
+ def sample(arr, temperature = 1.0)
138
+ narr = arr.normalize
139
+ picks = (CooCoo::Vector.rand(arr.size) - narr).each.with_index.select { |v, i| v <= 0.0 }.sort
140
+ pick = picks[rand(picks.size) * temperature]
141
+ (pick && pick[1]) || 0
142
+ end
143
+
144
+ require 'ostruct'
145
+
146
+ options = OpenStruct.new
147
+ options.encoder = AsciiInputEncoder.new
148
+ options.recurrent_size = 1024
149
+ options.activation_function = CooCoo::ActivationFunctions.from_name('Logistic')
150
+ options.epochs = 1000
151
+ options.model_path = "char-rnn.coo-coo_model"
152
+ options.input_path = nil
153
+ options.backprop_limit = nil
154
+ options.trainer = nil
155
+ options.sequence_size = 4
156
+ options.num_layers = 1
157
+ options.hidden_size = nil
158
+ options.num_recurrent_layers = 2
159
+ options.softmax = nil
160
+ options.verbose = false
161
+ options.generator = false
162
+ options.generator_temperature = 4
163
+ options.generator_amount = 140
164
+ options.generator_init = "\n"
165
+ options.sampler = method(:sample)
166
+
167
+ opts = CooCoo::OptionParser.new do |o|
168
+ o.on('-v', '--verbose') do
169
+ options.verbose = true
170
+ end
171
+
172
+ o.on('--little') do |v|
173
+ options.encoder = LittleInputEncoder.new
174
+ end
175
+
176
+ o.on('-m', '--model PATH') do |path|
177
+ options.model_path = path
178
+ end
179
+
180
+ o.on('-b', '--binary') do
181
+ options.binary = true
182
+ end
183
+
184
+ o.on('--recurrent-size NUMBER') do |size|
185
+ options.recurrent_size = size.to_i
186
+ end
187
+
188
+ o.on('--activation NAME') do |name|
189
+ options.activation_function = CooCoo::ActivationFunctions.from_name(name)
190
+ end
191
+
192
+ o.on('--epochs NUMBER') do |n|
193
+ options.epochs = n.to_i
194
+ end
195
+
196
+ o.on('--backprop-limit NUMBER') do |n|
197
+ options.backprop_limit = n.to_i
198
+ end
199
+
200
+ o.on('--hidden-size NUMBER') do |n|
201
+ options.hidden_size = n.to_i
202
+ end
203
+
204
+ o.on('--recurrent-layers NUMBER') do |n|
205
+ options.num_recurrent_layers = n.to_i
206
+ end
207
+
208
+ o.on('--softmax') do
209
+ options.softmax = true
210
+ options.cost_function = CooCoo::CostFunctions::CrossEntropy
211
+ end
212
+
213
+ o.on('-p', '--predict') do
214
+ options.trainer = nil
215
+ end
216
+
217
+ o.on('-t', '--trainer NAME') do |name|
218
+ options.trainer = CooCoo::Trainer.from_name(name)
219
+ end
220
+
221
+ o.on('-n', '--sequence-size NUMBER') do |n|
222
+ n = n.to_i
223
+ raise ArgumentError.new("sequence-size must be > 0") if n <= 0
224
+ options.sequence_size = n
225
+ end
226
+
227
+ o.on('--layers NUMBER') do |n|
228
+ n = n.to_i
229
+ raise ArgumentError.new("number of layers must be > 0") if n <= 0
230
+ options.num_layers = n
231
+ end
232
+
233
+ o.on('-g', '--generate AMOUNT') do |v|
234
+ options.generator = true
235
+ options.generator_amount = v.to_i
236
+ end
237
+
238
+ o.on('--generator-init STRING') do |v|
239
+ options.generator_init = v
240
+ end
241
+
242
+ o.on('--generator-temp NUMBER') do |v|
243
+ options.generator_temperature = v.to_i
244
+ end
245
+
246
+ o.on('--generator-sample-top') do
247
+ options.sampler = method(:sample_top)
248
+ end
249
+
250
+ o.on('--seed NUMBER') do |v|
251
+ srand(v.to_i)
252
+ end
253
+
254
+ o.on('-h', '--help') do
255
+ puts(o)
256
+ if options.trainer
257
+ t_opts, _ = options.trainer.options
258
+ puts(t_opts)
259
+ end
260
+
261
+ exit
262
+ end
263
+ end
264
+
265
+ argv = opts.parse!(ARGV)
266
+
267
+ if options.trainer
268
+ t_opts, trainer_options = options.trainer.options
269
+ argv = t_opts.parse!(argv)
270
+ end
271
+
272
+ options.input_path = argv[0]
273
+ encoder = options.encoder
274
+ options.hidden_size ||= encoder.vector_size
275
+
276
+ if File.exists?(options.model_path)
277
+ $stdout.print("Loading #{options.model_path}...")
278
+ $stdout.flush
279
+ net = if options.binary
280
+ Marshal.load(File.read(options.model_path))
281
+ else
282
+ CooCoo::TemporalNetwork.new(network: CooCoo::Network.load(options.model_path))
283
+ end
284
+ puts("\rLoaded #{options.model_path}:")
285
+ else
286
+ puts("Creating new network")
287
+ puts("\tNumber of inputs: #{encoder.vector_size}")
288
+ puts("\tNumber of layers: #{options.num_layers}")
289
+ puts("\tHidden size: #{options.hidden_size}#{' with mix' if options.hidden_size != encoder.vector_size}")
290
+ puts("\tRecurrent size: #{options.recurrent_size}")
291
+ puts("\tActivation: #{options.activation_function}")
292
+ puts("\tRecurrent layers: #{options.num_recurrent_layers}")
293
+
294
+ net = CooCoo::TemporalNetwork.new()
295
+ if options.hidden_size != encoder.vector_size
296
+ net.layer(CooCoo::Layer.new(encoder.vector_size, options.hidden_size, options.activation_function))
297
+ end
298
+
299
+ options.num_recurrent_layers.to_i.times do |n|
300
+ rec = CooCoo::Recurrence::Frontend.new(options.hidden_size, options.recurrent_size)
301
+ net.layer(rec)
302
+ options.num_layers.times do
303
+ net.layer(CooCoo::Layer.new(options.hidden_size + rec.recurrent_size, options.hidden_size + rec.recurrent_size, options.activation_function))
304
+ end
305
+
306
+ net.layer(rec.backend)
307
+ #net.layer(CooCoo::Layer.new(options.hidden_size, options.hidden_size, options.activation_function))
308
+ end
309
+
310
+ if options.hidden_size != encoder.vector_size
311
+ net.layer(CooCoo::Layer.new(options.hidden_size, encoder.vector_size, options.activation_function))
312
+ end
313
+
314
+ if options.softmax
315
+ net.layer(CooCoo::LinearLayer.new(encoder.vector_size, CooCoo::ActivationFunctions.from_name('ShiftedSoftMax')))
316
+ end
317
+ end
318
+
319
+ net.backprop_limit = options.backprop_limit
320
+
321
+ puts("\tAge: #{net.age}")
322
+ puts("\tInputs: #{net.num_inputs}")
323
+ puts("\tOutputs: #{net.num_outputs}")
324
+ puts("\tLayers: #{net.num_layers}")
325
+
326
+ data = if options.input_path
327
+ puts("Reading #{options.input_path}")
328
+ File.read(options.input_path)
329
+ else
330
+ puts("Reading stdin...")
331
+ $stdin.read
332
+ end
333
+ data = data.bytes
334
+ puts("Read #{data.size} bytes")
335
+
336
+ puts("Splitting into #{options.sequence_size} byte sequences.")
337
+ training_data = training_enumerator(data, options.sequence_size, encoder)
338
+
339
+ if options.trainer
340
+ puts("Training on #{data.size} bytes from #{options.input_path || "stdin"} in #{options.epochs} epochs in batches of #{trainer_options.batch_size} at a learning rate of #{trainer_options.learning_rate}...")
341
+
342
+ trainer = options.trainer
343
+ bar = CooCoo::ProgressBar.create(:total => (options.epochs * data.size / trainer_options.batch_size.to_f).ceil)
344
+ trainer.train({ network: net,
345
+ data: training_data.cycle(options.epochs),
346
+ }.merge(trainer_options.to_h)) do |stats|
347
+ cost = stats.average_loss.average
348
+ raise 'Cost went to NAN' if cost.nan?
349
+ status = [ "Cost #{cost.average} #{options.verbose ? cost : ''}" ]
350
+
351
+ File.write_to(options.model_path) do |f|
352
+ if options.binary
353
+ f.puts(Marshal.dump(net))
354
+ else
355
+ f.puts(net.to_hash.to_yaml)
356
+ end
357
+ end
358
+ status << "Saved to #{options.model_path}"
359
+ bar.log(status.join("\n"))
360
+ bar.increment
361
+ end
362
+ end
363
+
364
+ if options.generator
365
+ o, hidden_state = net.predict(encoder.encode_string(options.generator_init), {})
366
+ data.each do |b|
367
+ o, hidden_state = net.predict(encoder.encode_input(b), hidden_state)
368
+ end
369
+ options.generator_amount.to_i.times do |n|
370
+ c = options.sampler.call(o, options.generator_temperature)
371
+ c = encoder.decode_byte(c) if encoder.vector_size != 256
372
+ $stdout.write(c.chr)
373
+ $stdout.flush
374
+ o, hidden_state = net.predict(encoder.encode_input(c), hidden_state)
375
+ end
376
+
377
+ $stdout.puts
378
+ else
379
+ puts("Predicting:")
380
+ hidden_state = nil
381
+ s = data.size.times.collect do |i|
382
+ input = data[i, options.sequence_size].collect { |e| encoder.encode_input(e || 0) }
383
+ output, hidden_state = net.predict(input, hidden_state)
384
+ output.collect { |b| encoder.decode_output(b) }
385
+ end
386
+
387
+ s.each_with_index do |c, i|
388
+ input = data[i, options.sequence_size]
389
+ puts("#{i} #{input.inspect} -> #{c.inspect}\t#{encoder.decode_sequence(input)} -> #{encoder.decode_sequence(c)}")
390
+ end
391
+
392
+ puts(encoder.decode_sequence(s.collect(&:first)))
393
+ puts(encoder.decode_sequence(s.collect(&:last)))
394
+
395
+ hidden_state = nil
396
+ c = data[rand(data.size)]
397
+ s = data.size.times.collect do |i|
398
+ o, hidden_state = net.predict(encoder.encode_input(c), hidden_state)
399
+ c = encoder.decode_output(o)
400
+ end
401
+
402
+ puts(s.inspect)
403
+ puts(encoder.decode_sequence(s))
404
+ end
405
+ end