CooCoo 0.1.0

Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
data/lib/coo-coo/shell.rb
@@ -0,0 +1,2 @@
+ require 'coo-coo'
+ include CooCoo
data/lib/coo-coo/temporal_network.rb
@@ -0,0 +1,291 @@
+ require 'coo-coo/network'
+
+ module CooCoo
+   class TemporalNetwork
+     attr_reader :network
+     attr_accessor :backprop_limit
+
+     delegate :age, :to => :network
+     delegate :num_inputs, :to => :network
+     delegate :num_outputs, :to => :network
+     delegate :num_layers, :to => :network
+
+     def initialize(opts = Hash.new)
+       @network = opts.fetch(:network) { CooCoo::Network.new }
+       @backprop_limit = opts[:backprop_limit]
+     end
+
+     def layer(*args)
+       @network.layer(*args)
+       self
+     end
+
+     def layers
+       @network.layers
+     end
+
+     def prep_input(input)
+       if input.kind_of?(Enumerable)
+         CooCoo::Sequence[input.collect do |i|
+           @network.prep_input(i)
+         end]
+       else
+         @network.prep_input(input)
+       end
+     end
+
+     def prep_output_target(target)
+       if target.kind_of?(Enumerable)
+         CooCoo::Sequence[target.collect do |t|
+           @network.prep_output_target(t)
+         end]
+       else
+         @network.prep_output_target(target)
+       end
+     end
+
+     def final_output(outputs)
+       CooCoo::Sequence[outputs.collect { |o| @network.final_output(o) }]
+     end
+
+     def forward(input, hidden_state = nil, flattened = false)
+       if input.kind_of?(Enumerable)
+         o = input.collect do |i|
+           output, hidden_state = @network.forward(i, hidden_state, flattened)
+           output
+         end
+
+         return CooCoo::Sequence[o], hidden_state
+       else
+         @network.forward(input, hidden_state, flattened)
+       end
+     end
+
+     def predict(input, hidden_state = nil, flattened = false)
+       if input.kind_of?(Enumerable)
+         o = input.collect do |i|
+           outputs, hidden_state = @network.predict(i, hidden_state, flattened)
+           outputs
+         end
+
+         return o, hidden_state
+       else
+         @network.predict(input, hidden_state, flattened)
+       end
+     end
+
+     def learn(input, expecting, rate, cost_function = CostFunctions::MeanSquare, hidden_state = nil)
+       expecting.zip(input).each do |target, input|
+         n, hidden_state = @network.learn(input, target, rate, cost_function, hidden_state)
+       end
+
+       return self, hidden_state
+     end
+
+     def backprop(inputs, outputs, errors, hidden_state = nil)
+       errors = Sequence.new(outputs.size) { errors / outputs.size.to_f } unless errors.kind_of?(Sequence)
+
+       o = outputs.zip(inputs, errors).reverse.collect do |output, input, err|
+         output, hidden_state = @network.backprop(input, output, err, hidden_state)
+         output
+       end.reverse
+
+       return Sequence[o], hidden_state
+     end
+
+     def weight_deltas(inputs, outputs, deltas)
+       e = inputs.zip(outputs, deltas)
+       e = e.last(@backprop_limit) if @backprop_limit
+
+       deltas = e.collect do |input, output, delta|
+         @network.weight_deltas(input, output, delta)
+       end
+
+       accumulate_deltas(deltas)
+     end
+
+     def adjust_weights!(deltas)
+       @network.adjust_weights!(deltas)
+       self
+     end
+
+     def update_weights!(inputs, outputs, deltas)
+       adjust_weights!(weight_deltas(inputs, outputs, deltas))
+     end
+
+     def to_hash
+       @network.to_hash.merge({ type: self.class.name })
+     end
+
+     def update_from_hash!(h)
+       @network.update_from_hash!(h)
+       self
+     end
+
+     def self.from_hash(h)
+       net = CooCoo::Network.from_hash(h)
+       self.new(network: net)
+     end
+
+     private
+     def accumulate_deltas(deltas)
+       weight = 1.0 / deltas.size.to_f
+
+       acc = deltas[0]
+       deltas[1, deltas.size].each do |step|
+         step.each_with_index do |layer, i|
+           acc[i] += layer * weight
+         end
+       end
+
+       acc
+     end
+   end
+ end
+
+ if __FILE__ == $0
+   require 'coo-coo'
+   require 'pp'
+
+   def mark_random(v)
+     bingo = rand < 0.1
+     if bingo
+       v = v.dup
+       v[0] = 1.0
+       return v, true
+     else
+       return v, false
+     end
+   end
+
+   INPUT_LENGTH = 2
+   OUTPUT_LENGTH = 2
+   SEQUENCE_LENGTH = ENV.fetch('SEQUENCE_LENGTH', 6).to_i
+   HIDDEN_LENGTH = 10
+   RECURRENT_LENGTH = SEQUENCE_LENGTH * 4 # boosts the signal
+   DELAY = ENV.fetch('DELAY', 2).to_i
+   SINGLE_LAYER = (ENV.fetch('SINGLE_LAYER', 'true') == "true")
+
+   activation_function = CooCoo::ActivationFunctions.from_name(ENV.fetch('ACTIVATION', 'Logistic'))
+
+   net = CooCoo::TemporalNetwork.new
+   2.times do |n|
+     rec = CooCoo::Recurrence::Frontend.new(INPUT_LENGTH, RECURRENT_LENGTH)
+     net.layer(rec)
+     if SINGLE_LAYER
+       net.layer(CooCoo::FullyConnectedLayer.new(INPUT_LENGTH + rec.recurrent_size, OUTPUT_LENGTH + rec.recurrent_size))
+       net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, activation_function))
+     else
+       net.layer(CooCoo::FullyConnectedLayer.new(INPUT_LENGTH + rec.recurrent_size, HIDDEN_LENGTH))
+       net.layer(CooCoo::LinearLayer.new(HIDDEN_LENGTH, activation_function))
+       net.layer(CooCoo::FullyConnectedLayer.new(HIDDEN_LENGTH, OUTPUT_LENGTH + rec.recurrent_size))
+       net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, activation_function))
+     end
+     #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::LeakyReLU.instance))
+     #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::Normalize.instance))
+     #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
+     #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::TanH.instance))
+     #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ZeroSafeMinMax.instance))
+     net.layer(rec.backend)
+   end
+   #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ReLU.instance))
+   #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
+   #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::Normalize.instance))
+   #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::Logistic.instance))
+
+   input_seqs = 2.times.collect do
+     SEQUENCE_LENGTH.times.collect do
+       CooCoo::Vector.zeros(INPUT_LENGTH)
+     end
+   end
+   input_seqs << SEQUENCE_LENGTH.times.collect { 0.45 * CooCoo::Vector.rand(INPUT_LENGTH) }
+   input_seqs << SEQUENCE_LENGTH.times.collect { 0.5 * CooCoo::Vector.rand(INPUT_LENGTH) }
+   input_seqs.first[0][0] = 1.0
+   input_seqs.last[0][0] = 1.0
+
+   target_seqs = input_seqs.length.times.collect do
+     SEQUENCE_LENGTH.times.collect do
+       CooCoo::Vector.zeros(OUTPUT_LENGTH)
+     end
+   end
+   target_seqs.first[DELAY][0] = 1.0
+   target_seqs.last[DELAY][0] = 1.0
+
+   def cost(net, expecting, outputs)
+     CooCoo::Sequence[outputs.zip(expecting).collect do |output, target|
+       CooCoo::CostFunctions::MeanSquare.derivative(net.prep_output_target(target), output.last)
+     end]
+   end
+
+   learning_rate = ENV.fetch("RATE", 0.3).to_f
+   print_rate = ENV.fetch("PRINT_RATE", 500).to_i
+
+   ENV.fetch("LOOPS", 100).to_i.times do |n|
+     input_seqs.zip(target_seqs).each do |input_seq, target_seq|
+       fuzz = Random.rand(input_seq.length)
+       input_seq = input_seq.rotate(fuzz)
+       target_seq = target_seq.rotate(fuzz)
+
+       outputs, hidden_state = net.forward(input_seq, Hash.new)
+
+       if n % print_rate == 0
+         input_seq.zip(outputs, target_seq).each do |input, output, target|
+           puts("#{n}\t#{input} -> #{target}\n\t#{output.join("\n\t")}\n")
+         end
+       end
+
+       c = cost(net, net.prep_output_target(target_seq), outputs)
+       all_deltas, hidden_state = net.backprop(input_seq, outputs, c, hidden_state)
+       net.update_weights!(input_seq, outputs, all_deltas * learning_rate)
+       if n % 500 == 0
+         puts("\tcost\t#{(c * c).sum}\n\t\t#{c.to_a.join("\n\t\t")}")
+         puts
+       end
+     end
+   end
+
+   puts
+
+   2.times do |n|
+     input_seqs.zip(target_seqs).each_with_index do |(input_seq, target_seq), i|
+       input_seq = input_seq.collect do |input|
+         input, bingo = mark_random(input)
+         input
+       end
+
+       outputs, hidden_state = net.predict(input_seq, Hash.new)
+
+       outputs.zip(input_seq, target_seq).each_with_index do |(output, input, target), ii|
+         bingo = input[0] == 1.0
+         puts("#{n},#{i},#{ii}\t#{bingo ? '*' : ''}#{input} -> #{target}\t#{output}")
+       end
+     end
+   end
+
+   hidden_state = nil
+   input = CooCoo::Vector.zeros(INPUT_LENGTH)
+   input[0] = 1.0
+   outputs = (SEQUENCE_LENGTH * 2).times.collect do |n|
+     output, hidden_state = net.predict(input, hidden_state)
+     puts("#{n}\t#{input}\t#{output}")
+     input[0] = 0.0
+
+     output
+   end
+
+   outputs = outputs.collect { |o| o[0] }
+   (min, min_i), (max, max_i) = outputs.each_with_index.minmax
+   puts("Min output index = #{min_i}\t#{min_i == 0}")
+   puts("Max output index = #{max_i}\t#{max_i == DELAY}")
+   puts("output[0] is <MAX = #{outputs[0] < max}")
+   puts("output[DELAY] is > [0] = #{outputs[DELAY] > outputs[0]}")
+   puts("output[DELAY] is > [DELAY-1] = #{outputs[DELAY] > outputs[DELAY-1]}")
+   puts("output[DELAY] is > [DELAY+1] = #{outputs[DELAY] > outputs[DELAY+1]}")
+   puts("Max output index - 1 is <MAX = #{outputs[max_i-1] < max}")
+   if max_i < outputs.length - 1
+     puts("Max output index + 1 is <MAX = #{outputs[max_i+1] < max}")
+   end
+
+   puts
+   pp(net.to_hash)
+ end
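
TemporalNetwork wraps a plain Network so that forward, predict, backprop, and learn accept whole sequences, threading the recurrent hidden state through each time step. A minimal usage sketch, assuming the same layer classes the demo above uses (the sizes here are made up):

    require 'coo-coo'

    net = CooCoo::TemporalNetwork.new
    rec = CooCoo::Recurrence::Frontend.new(2, 8)
    net.layer(rec)
    net.layer(CooCoo::FullyConnectedLayer.new(2 + rec.recurrent_size, 2 + rec.recurrent_size))
    net.layer(CooCoo::LinearLayer.new(2 + rec.recurrent_size, CooCoo::ActivationFunctions.from_name('Logistic')))
    net.layer(rec.backend)

    # Passing an Enumerable runs the whole sequence and returns a Sequence of outputs.
    seq = 4.times.collect { CooCoo::Vector.zeros(2) }
    outputs, hidden_state = net.predict(seq, Hash.new)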
data/lib/coo-coo/trainer.rb
@@ -0,0 +1,21 @@
+ require 'parallel'
+ require 'coo-coo/consts'
+ require 'coo-coo/debug'
+ require 'coo-coo/trainer/stochastic'
+ require 'coo-coo/trainer/momentum_stochastic'
+ require 'coo-coo/trainer/batch'
+
+ module CooCoo
+   module Trainer
+     def self.list
+       constants.
+         select { |c| const_get(c).ancestors.include?(Base) }.
+         collect(&:to_s).
+         sort
+     end
+
+     def self.from_name(name)
+       const_get(name).instance
+     end
+   end
+ end
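
Trainer.list scans the module's constants for classes descending from Base, and from_name returns the matching singleton. A quick sketch of how a caller might pick a trainer (the list shown is illustrative; note Base itself passes the ancestors check):

    require 'coo-coo/trainer'

    CooCoo::Trainer.list    # => e.g. ["Base", "Batch", "MomentumStochastic", "Stochastic"]
    trainer = CooCoo::Trainer.from_name('Batch')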
data/lib/coo-coo/trainer/base.rb
@@ -0,0 +1,67 @@
+ require 'singleton'
+ require 'ostruct'
+ require 'coo-coo/option_parser'
+
+ module CooCoo
+   module Trainer
+     # @abstract Defines and documents the interface for the trainers.
+     class Base
+       include Singleton
+
+       # Returns a user-friendly name; by default the last part of the class name.
+       def name
+         self.class.name.split('::').last
+       end
+
+       DEFAULT_OPTIONS = {
+         cost: CostFunctions::MeanSquare,
+         learning_rate: 1/3.0,
+         batch_size: 1024
+       }
+
+       # Returns a command line {OptionParser} to gather the trainer's
+       # options.
+       # @return [[OptionParser, OpenStruct]] an {OptionParser} to parse command line options and an {OpenStruct} to store their values.
+       def options(defaults = DEFAULT_OPTIONS)
+         options = OpenStruct.new(defaults)
+
+         parser = OptionParser.new do |o|
+           o.banner = "#{name} trainer options"
+
+           o.accept(CostFunctions::Base) do |v|
+             CostFunctions.from_name(v)
+           end
+
+           o.on('--cost NAME', '--cost-function NAME', "The function to minimize during training. Choices are: #{CostFunctions.named_classes.join(', ')}", CostFunctions::Base) do |v|
+             options.cost_function = v
+           end
+
+           o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float, 'Multiplier for the changes the network calculates.') do |n|
+             options.learning_rate = n
+           end
+
+           o.on('-n', '--batch-size INTEGER', Integer, 'Number of examples to train against before yielding.') do |n|
+             options.batch_size = n
+           end
+
+           yield(o, options) if block_given?
+         end
+
+         [ parser, options ]
+       end
+
+       # Trains a network by iterating through a set of target, input pairs.
+       #
+       # @param options [Hash, OpenStruct] Options hash
+       # @option options [Network, TemporalNetwork] :network The network to train.
+       # @option options [Array<Array<Vector, Vector>>, Enumerator<Vector, Vector>] :data An array of +[ target, input ]+ pairs to be used for the training.
+       # @option options [Float] :learning_rate The multiplier of change in the network's weights.
+       # @option options [Integer] :batch_size How many examples to pull from the training data in each batch.
+       # @option options [CostFunctions::Base] :cost_function The function used to calculate the loss and how the network should change given bad outputs.
+       # @yield [BatchStats] after each batch
+       def train(options, &block)
+         raise NotImplementedError.new
+       end
+     end
+   end
+ end
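
Since #options returns the parser and the OpenStruct together, a command line tool can run ARGV through the parser and then read the struct. A sketch, under the assumption that this OptionParser wrapper (coo-coo/option_parser, not shown in this diff) behaves like the stdlib one:

    parser, opts = CooCoo::Trainer.from_name('Batch').options
    parser.parse!(ARGV)
    opts.learning_rate   # set by --rate, or the 1/3.0 default
    opts.batch_size      # set by --batch-size, or 1024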
data/lib/coo-coo/trainer/batch.rb
@@ -0,0 +1,82 @@
+ require 'coo-coo/cost_functions'
+ require 'coo-coo/sequence'
+ require 'coo-coo/trainer/base'
+ require 'coo-coo/trainer/batch_stats'
+
+ module CooCoo
+   module Trainer
+     # Trains a network by adjusting its weights only once per batch. This opens
+     # up parallelism during learning, as more examples can be run at one time.
+     class Batch < Base
+       DEFAULT_OPTIONS = Base::DEFAULT_OPTIONS.merge(processes: Parallel.processor_count)
+
+       def options
+         super(DEFAULT_OPTIONS) do |o, options|
+           o.on('--processes INTEGER', Integer, 'Number of threads or processes to use for the batch.') do |n|
+             options.processes = n
+           end
+         end
+       end
+
+       # @option options [Integer] :processes How many threads or processes to use for the batch. Defaults to the processor count, {Parallel.processor_count}.
+       def train(options, &block)
+         options = options.to_h
+         network = options.fetch(:network)
+         training_data = options.fetch(:data)
+         learning_rate = options.fetch(:learning_rate, 0.3)
+         batch_size = options.fetch(:batch_size, 1024)
+         cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
+         processes = options.fetch(:processes, Parallel.processor_count)
+
+         t = Time.now
+
+         training_data.each_slice(batch_size).with_index do |batch, i|
+           deltas_errors = in_parallel(processes, batch) do |(expecting, input)|
+             output, hidden_state = network.forward(input, Hash.new)
+             target = network.prep_output_target(expecting)
+             final_output = network.final_output(output)
+             errors = cost_function.derivative(target, final_output)
+             new_deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
+             new_deltas = network.weight_deltas(input, output, new_deltas * learning_rate)
+
+             [ new_deltas, cost_function.call(target, final_output) ]
+           end
+
+           deltas, total_errors = deltas_errors.transpose
+           network.adjust_weights!(accumulate_deltas(deltas))
+
+           if block
+             block.call(BatchStats.new(self, i, batch_size, Time.now - t, CooCoo::Sequence[total_errors].sum))
+           end
+
+           t = Time.now
+         end
+       end
+
+       protected
+
+       def in_parallel(processes, *args, &block)
+         opts = if CUDA.available?
+                  # CUDA can't fork, so keep everything in a single Ruby process
+                  { in_threads: processes }
+                else
+                  { in_processes: processes }
+                end
+         Parallel.map(*args, opts, &block)
+       end
+
+       def accumulate_deltas(deltas)
+         weight = 1.0 / deltas.size.to_f
+
+         acc = deltas[0]
+         deltas[1, deltas.size].each do |step|
+           step.each_with_index do |layer, i|
+             acc[i] += layer * weight
+           end
+         end
+
+         acc
+       end
+     end
+   end
+ end
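
Putting the batch trainer together: train pulls everything from one options hash and yields a BatchStats after each slice. A hedged sketch, where net and training_pairs are hypothetical placeholders and the BatchStats accessors are defined in trainer/batch_stats.rb, which this diff doesn't show:

    trainer = CooCoo::Trainer.from_name('Batch')
    trainer.train(network: net,
                  data: training_pairs,   # [ target, input ] pairs
                  learning_rate: 0.3,
                  batch_size: 128,
                  processes: 4) do |stats|
      pp(stats)   # per-batch number, timing, and summed error
    end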