CooCoo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,2 @@
1
+ require 'coo-coo'
2
+ include CooCoo
@@ -0,0 +1,291 @@
1
+ require 'coo-coo/network'
2
+
3
module CooCoo
  # Wraps a {CooCoo::Network} so it can be run over sequences of inputs.
  # For sequence input each element is handed to the wrapped network in
  # order, and the hidden state returned by one step is passed to the next.
  class TemporalNetwork
    # The wrapped network that performs every single-step operation.
    attr_reader :network
    # When set, {#weight_deltas} only considers this many trailing steps of
    # a sequence. +nil+ means every step is used.
    attr_accessor :backprop_limit

    delegate :age, :to => :network
    delegate :num_inputs, :to => :network
    delegate :num_outputs, :to => :network
    delegate :num_layers, :to => :network

    # @param opts [Hash]
    # @option opts [Network] :network The network to wrap. A new, empty
    #   {CooCoo::Network} is created when the key is absent.
    # @option opts [Integer, nil] :backprop_limit See {#backprop_limit}.
    def initialize(opts = Hash.new)
      @network = opts.fetch(:network) { CooCoo::Network.new }
      @backprop_limit = opts[:backprop_limit]
    end

    # Adds a layer to the wrapped network.
    # @return [self] so calls can be chained.
    def layer(*args)
      @network.layer(*args)
      self
    end

    # @return the wrapped network's layers.
    def layers
      @network.layers
    end

    # Prepares an input with the wrapped network's +prep_input+.
    # Enumerable input is prepared element by element and returned as a
    # {CooCoo::Sequence}; anything else is passed straight through.
    def prep_input(input)
      if input.kind_of?(Enumerable)
        CooCoo::Sequence[input.collect { |step| @network.prep_input(step) }]
      else
        @network.prep_input(input)
      end
    end

    # Prepares a target with the wrapped network's +prep_output_target+.
    # Enumerable targets are prepared element by element and returned as a
    # {CooCoo::Sequence}; anything else is passed straight through.
    def prep_output_target(target)
      if target.kind_of?(Enumerable)
        CooCoo::Sequence[target.collect { |step| @network.prep_output_target(step) }]
      else
        @network.prep_output_target(target)
      end
    end

    # Applies the wrapped network's +final_output+ to every step.
    # @return [Sequence] one final output per step.
    def final_output(outputs)
      CooCoo::Sequence[outputs.collect { |o| @network.final_output(o) }]
    end

    # Runs a forward pass.
    #
    # For Enumerable input every element is forwarded in order, feeding the
    # hidden state returned by one step into the next.
    #
    # @return [Array] for Enumerable input, the pair
    #   +[Sequence of per-step outputs, last hidden state]+; otherwise
    #   whatever the wrapped network's +forward+ returns.
    def forward(input, hidden_state = nil, flattened = false)
      if input.kind_of?(Enumerable)
        step_outputs = input.collect do |step|
          output, hidden_state = @network.forward(step, hidden_state, flattened)
          output
        end

        return CooCoo::Sequence[step_outputs], hidden_state
      else
        @network.forward(input, hidden_state, flattened)
      end
    end

    # Like {#forward} but uses the wrapped network's +predict+.
    #
    # @return [Array] for Enumerable input, the pair
    #   +[Array of per-step outputs, last hidden state]+. Note that the
    #   outputs are a plain Array here, not a {CooCoo::Sequence}.
    #   Otherwise whatever the wrapped network's +predict+ returns.
    def predict(input, hidden_state = nil, flattened = false)
      if input.kind_of?(Enumerable)
        step_outputs = input.collect do |step|
          outputs, hidden_state = @network.predict(step, hidden_state, flattened)
          outputs
        end

        return step_outputs, hidden_state
      else
        @network.predict(input, hidden_state, flattened)
      end
    end

    # Trains the wrapped network on each +(target, input)+ pair in order,
    # carrying the hidden state from one step to the next.
    #
    # @return [Array] the pair +[self, last hidden state]+.
    def learn(input, expecting, rate, cost_function = CostFunctions::MeanSquare, hidden_state = nil)
      expecting.zip(input).each do |target, step_input|
        _, hidden_state = @network.learn(step_input, target, rate, cost_function, hidden_state)
      end

      return self, hidden_state
    end

    # Back propagates errors through the sequence, last step first.
    #
    # When +errors+ is not a {Sequence} it is divided by the number of steps
    # and that same value is used for every step.
    #
    # @return [Array] the pair +[Sequence of per-step results in the
    #   original step order, hidden state after the earliest step]+.
    def backprop(inputs, outputs, errors, hidden_state = nil)
      errors = Sequence.new(outputs.size) { errors / outputs.size.to_f } unless errors.kind_of?(Sequence)

      o = outputs.zip(inputs, errors).reverse.collect do |output, input, err|
        output, hidden_state = @network.backprop(input, output, err, hidden_state)
        output
      end.reverse

      return Sequence[o], hidden_state
    end

    # Computes the wrapped network's weight deltas for every step and
    # averages them into a single set. Only the last {#backprop_limit}
    # steps are used when that limit is set.
    def weight_deltas(inputs, outputs, deltas)
      e = inputs.zip(outputs, deltas)
      e = e.last(@backprop_limit) if @backprop_limit

      deltas = e.collect do |input, output, delta|
        @network.weight_deltas(input, output, delta)
      end

      accumulate_deltas(deltas)
    end

    # Applies already computed weight deltas to the wrapped network.
    # @return [self]
    def adjust_weights!(deltas)
      @network.adjust_weights!(deltas)
      self
    end

    # Computes the averaged weight deltas and applies them.
    # @return [self]
    def update_weights!(inputs, outputs, deltas)
      adjust_weights!(weight_deltas(inputs, outputs, deltas))
    end

    # @return [Hash] the wrapped network's hash with +:type+ set to this
    #   class's name.
    def to_hash
      @network.to_hash.merge({ type: self.class.name })
    end

    # Updates the wrapped network from a hash.
    # @return [self]
    def update_from_hash!(h)
      @network.update_from_hash!(h)
      self
    end

    # Builds a temporal network around a {CooCoo::Network} loaded from +h+.
    def self.from_hash(h)
      net = CooCoo::Network.from_hash(h)
      self.new(network: net)
    end

    private

    # Averages per-step weight deltas into one set of per-layer deltas.
    #
    # Every step, including the first, contributes with weight
    # +1 / deltas.size+. The result is built in a new Array so the first
    # step's deltas are not modified.
    #
    # @param deltas [Array] one entry per step, each an enumerable of
    #   per-layer deltas that support +*+ with a Float and +++.
    # @return [Array] the averaged per-layer deltas. A single step is
    #   returned unchanged.
    # @raise [ArgumentError] if +deltas+ is empty.
    def accumulate_deltas(deltas)
      raise ArgumentError, 'no deltas to accumulate' if deltas.empty?
      # Nothing to average; also avoids requiring +*+ on the lone step.
      return deltas[0] if deltas.size == 1

      weight = 1.0 / deltas.size.to_f

      acc = deltas[0].collect { |layer| layer * weight }
      deltas.drop(1).each do |step|
        step.each_with_index do |layer, i|
          acc[i] += layer * weight
        end
      end

      acc
    end
  end
end
145
+
146
# Demo / smoke test, run only when this file is executed directly.
#
# Builds a small recurrent network, trains it on sequences where a 1.0 in
# input[0] should produce a 1.0 in output[0] DELAY steps later, then prints
# predictions and a few sanity checks. Tunable through the environment
# variables SEQUENCE_LENGTH, DELAY, SINGLE_LAYER, ACTIVATION, RATE,
# PRINT_RATE and LOOPS.
if __FILE__ == $0
  require 'coo-coo'
  require 'pp'

  # With a probability of 0.1 returns a copy of +v+ whose first element is
  # set to 1.0 together with +true+; otherwise returns +v+ itself and +false+.
  def mark_random(v)
    return v, false unless rand < 0.1

    marked = v.dup
    marked[0] = 1.0
    return marked, true
  end

  INPUT_LENGTH = 2
  OUTPUT_LENGTH = 2
  SEQUENCE_LENGTH = ENV.fetch('SEQUENCE_LENGTH', 6).to_i
  HIDDEN_LENGTH = 10
  RECURRENT_LENGTH = SEQUENCE_LENGTH * 4 # boosts the signal
  DELAY = ENV.fetch('DELAY', 2).to_i
  SINGLE_LAYER = (ENV.fetch('SINGLE_LAYER', 'true') == "true")

  activation_function = CooCoo::ActivationFunctions.from_name(ENV.fetch('ACTIVATION', 'Logistic'))

  # Two recurrent stages, each wrapped by a recurrence frontend / backend pair.
  net = CooCoo::TemporalNetwork.new
  2.times do
    rec = CooCoo::Recurrence::Frontend.new(INPUT_LENGTH, RECURRENT_LENGTH)
    net.layer(rec)
    if SINGLE_LAYER
      net.layer(CooCoo::FullyConnectedLayer.new(INPUT_LENGTH + rec.recurrent_size, OUTPUT_LENGTH + rec.recurrent_size))
      net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, activation_function))
    else
      net.layer(CooCoo::FullyConnectedLayer.new(INPUT_LENGTH + rec.recurrent_size, HIDDEN_LENGTH))
      net.layer(CooCoo::LinearLayer.new(HIDDEN_LENGTH, activation_function))
      net.layer(CooCoo::FullyConnectedLayer.new(HIDDEN_LENGTH, OUTPUT_LENGTH + rec.recurrent_size))
      net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, activation_function))
    end
    #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::LeakyReLU.instance))
    #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::Normalize.instance))
    #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
    #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::TanH.instance))
    #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ZeroSafeMinMax.instance))
    net.layer(rec.backend)
  end
  #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH + rec.recurrent_size, CooCoo::ActivationFunctions::ReLU.instance))
  #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::ShiftedSoftMax.instance))
  #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::Normalize.instance))
  #net.layer(CooCoo::LinearLayer.new(OUTPUT_LENGTH, CooCoo::ActivationFunctions::Logistic.instance))

  # Four input sequences: two all-zero ones and two of scaled random vectors.
  # The first and the last get a 1.0 marker in element 0 of their first step.
  input_seqs = 2.times.collect do
    SEQUENCE_LENGTH.times.collect do
      CooCoo::Vector.zeros(INPUT_LENGTH)
    end
  end
  input_seqs << SEQUENCE_LENGTH.times.collect { 0.45 * CooCoo::Vector.rand(INPUT_LENGTH) }
  input_seqs << SEQUENCE_LENGTH.times.collect { 0.5 * CooCoo::Vector.rand(INPUT_LENGTH) }
  input_seqs.first[0][0] = 1.0
  input_seqs.last[0][0] = 1.0

  # Targets are all zero except for a 1.0 placed DELAY steps into the
  # sequences whose input carries the marker.
  target_seqs = input_seqs.length.times.collect do
    SEQUENCE_LENGTH.times.collect do
      CooCoo::Vector.zeros(OUTPUT_LENGTH)
    end
  end
  target_seqs.first[DELAY][0] = 1.0
  target_seqs.last[DELAY][0] = 1.0

  # Mean square cost derivative for every step, computed from the last
  # element of each step's output.
  def cost(net, expecting, outputs)
    CooCoo::Sequence[outputs.zip(expecting).collect do |output, target|
      CooCoo::CostFunctions::MeanSquare.derivative(net.prep_output_target(target), output.last)
    end]
  end

  learning_rate = ENV.fetch("RATE", 0.3).to_f
  print_rate = ENV.fetch("PRINT_RATE", 500).to_i

  ENV.fetch("LOOPS", 100).to_i.times do |n|
    input_seqs.zip(target_seqs).each do |input_seq, target_seq|
      # Rotate input and target by the same random offset.
      fuzz = Random.rand(input_seq.length)
      input_seq = input_seq.rotate(fuzz)
      target_seq = target_seq.rotate(fuzz)

      outputs, hidden_state = net.forward(input_seq, Hash.new)

      if n % print_rate == 0
        input_seq.zip(outputs, target_seq).each do |input, output, target|
          puts("#{n}\t#{input} -> #{target}\n\t#{output.join("\n\t")}\n")
        end
      end

      c = cost(net, net.prep_output_target(target_seq), outputs)
      all_deltas, hidden_state = net.backprop(input_seq, outputs, c, hidden_state)
      net.update_weights!(input_seq, outputs, all_deltas * learning_rate)
      # Report the cost on the same schedule as the step dump above.
      if n % print_rate == 0
        puts("\tcost\t#{(c * c).sum}\n\t\t#{c.to_a.join("\n\t\t")}")
        puts
      end
    end
  end

  puts

  # Predict twice over every sequence with markers sprinkled in at random.
  2.times do |n|
    input_seqs.zip(target_seqs).each_with_index do |(input_seq, target_seq), i|
      input_seq = input_seq.collect do |input|
        marked, _ = mark_random(input)
        marked
      end

      outputs, _ = net.predict(input_seq, Hash.new)

      outputs.zip(input_seq, target_seq).each_with_index do |(output, input, target), ii|
        bingo = input[0] == 1.0
        puts("#{n},#{i},#{ii}\t#{bingo ? '*' : ''}#{input} -> #{target}\t#{output}")
      end
    end
  end

  # Feed a single marked input followed by unmarked ones, one step at a
  # time, carrying the hidden state along.
  hidden_state = nil
  input = CooCoo::Vector.zeros(INPUT_LENGTH)
  input[0] = 1.0
  outputs = (SEQUENCE_LENGTH * 2).times.collect do |n|
    output, hidden_state = net.predict(input, hidden_state)
    puts("#{n}\t#{input}\t#{output}")
    input[0] = 0.0

    output
  end

  outputs = outputs.collect { |o| o[0] }
  (min, min_i), (max, max_i) = outputs.each_with_index.minmax
  puts("Min output index = #{min_i}\t#{min_i == 0}")
  puts("Max output index = #{max_i}\t#{max_i == DELAY}")
  puts("output[0] is <MAX = #{outputs[0] < max}")
  puts("output[DELAY] is > [0] = #{outputs[DELAY] > outputs[0]}")
  puts("output[DELAY] is > [DELAY-1] = #{outputs[DELAY] > outputs[DELAY-1]}")
  puts("output[DELAY] is > [DELAY+1] = #{outputs[DELAY] > outputs[DELAY+1]}")
  puts("Max output index - 1 is <MAX = #{outputs[max_i-1] < max}")
  if max_i < outputs.length - 1
    puts("Max output index + 1 is <MAX = #{outputs[max_i+1] < max}")
  end

  puts
  pp(net.to_hash)
end
@@ -0,0 +1,21 @@
1
+ require 'parallel'
2
+ require 'coo-coo/consts'
3
+ require 'coo-coo/debug'
4
+ require 'coo-coo/trainer/stochastic'
5
+ require 'coo-coo/trainer/momentum_stochastic'
6
+ require 'coo-coo/trainer/batch'
7
+
8
module CooCoo
  # Namespace for the training strategies, with helpers to enumerate them
  # and to look one up by name.
  module Trainer
    # Lists the trainers defined in this namespace.
    #
    # Only constants that are classes equal to or descending from {Base} are
    # reported. Any other constant (plain values, modules, unrelated
    # classes) is skipped rather than raising.
    #
    # @return [Array<String>] the constant names, sorted.
    def self.list
      constants.
        select { |c|
          value = const_get(c)
          value.is_a?(Class) && value <= Base
        }.
        collect(&:to_s).
        sort
    end

    # Looks up a trainer by its constant name.
    #
    # @param name [String, Symbol] the trainer's constant name.
    # @return the singleton instance of the named trainer.
    # @raise [NameError] if no such constant exists.
    def self.from_name(name)
      const_get(name).instance
    end
  end
end
@@ -0,0 +1,67 @@
1
+ require 'singleton'
2
+ require 'ostruct'
3
+ require 'coo-coo/option_parser'
4
+
5
module CooCoo
  module Trainer
    # @abstract Defines and documents the interface for the trainers.
    class Base
      include Singleton

      # Returns a user friendly name, like the class name by default.
      def name
        self.class.name.split('::').last
      end

      # Defaults handed to {#options}.
      #
      # +:cost_function+ is the key the option parser writes and that
      # {#train} is documented to read. +:cost+ carries the same default
      # and is kept so existing readers of that key keep working.
      DEFAULT_OPTIONS = {
        cost: CostFunctions::MeanSquare,
        cost_function: CostFunctions::MeanSquare,
        learning_rate: 1/3.0,
        batch_size: 1024
      }.freeze

      # Returns a command line {OptionParser} to gather the trainer's
      # options.
      #
      # @param defaults [Hash] initial option values.
      # @yield [parser, options] lets subclasses register additional
      #   switches on the parser being built.
      # @return [[OptionParser, OpenStruct]] an {OptionParser} to parse command line options and hash to store their values.
      def options(defaults = DEFAULT_OPTIONS)
        options = OpenStruct.new(defaults)

        parser = OptionParser.new do |o|
          o.banner = "#{name} trainer options"

          # Converts a cost function's name into the cost function itself.
          o.accept(CostFunctions::Base) do |v|
            CostFunctions.from_name(v)
          end

          o.on('--cost NAME', '--cost-function NAME', "The function to minimize during training. Choices are: #{CostFunctions.named_classes.join(', ')}", CostFunctions::Base) do |v|
            options.cost_function = v
          end

          o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float, 'Multiplier for the changes the network calculates.') do |n|
            options.learning_rate = n
          end

          o.on('-n', '--batch-size INTEGER', Integer, 'Number of examples to train against before yielding.') do |n|
            options.batch_size = n
          end

          yield(o, options) if block_given?
        end

        [ parser, options ]
      end

      # Trains a network by iterating through a set of target, input pairs.
      #
      # @param options [Hash, OpenStruct] Options hash
      # @option options [Network, TemporalNetwork] :network The network to train.
      # @option options [Array<Array<Vector, Vector>>, Enumerator<Vector, Vector>] :data An array of +[ target, input ]+ pairs to be used for the training.
      # @option options [Float] :learning_rate The multiplier of change in the network's weights.
      # @option options [Integer] :batch_size How many examples to pull from the training data in each batch
      # @option options [CostFunctions::Base] :cost_function The function to use to calculate the loss and how to change the network from bad outputs.
      # @yield [BatchStats] after each batch
      # @raise [NotImplementedError] always; subclasses provide the implementation.
      def train(options, &block)
        raise NotImplementedError
      end
    end
  end
end
@@ -0,0 +1,82 @@
1
+ require 'coo-coo/cost_functions'
2
+ require 'coo-coo/sequence'
3
+ require 'coo-coo/trainer/base'
4
+ require 'coo-coo/trainer/batch_stats'
5
+
6
module CooCoo
  module Trainer
    # Trains a network by only adjusting the network once a batch. This opens
    # up parallelism during learning as more examples can be ran at one time.
    class Batch < Base
      # The base defaults plus +:processes+, which defaults to the
      # processor count reported by Parallel.
      DEFAULT_OPTIONS = Base::DEFAULT_OPTIONS.merge(processes: Parallel.processor_count).freeze

      # Extends the base option parser with a +--processes+ switch.
      # @return [[OptionParser, OpenStruct]] see {Base#options}.
      def options
        super(DEFAULT_OPTIONS) do |o, options|
          o.on('--processes INTEGER', Integer, 'Number of threads or processes to use for the batch.') do |n|
            options.processes = n
          end
        end
      end

      # Trains the network one batch at a time.
      #
      # For every +[ expecting, input ]+ pair of a batch the weight deltas
      # and the cost are computed in parallel. The deltas are then averaged
      # and applied to the network once per batch.
      #
      # @param options [Hash, OpenStruct] see {Base#train}; converted with +to_h+.
      # @option options [Integer] :processes How many threads or processes to use for the batch. Defaults to the processor count, {Parallel#processor_count}.
      # @yield [BatchStats] after each batch, when a block is given.
      def train(options, &block)
        options = options.to_h
        network = options.fetch(:network)
        training_data = options.fetch(:data)
        learning_rate = options.fetch(:learning_rate, 0.3)
        batch_size = options.fetch(:batch_size, 1024)
        cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
        processes = options.fetch(:processes, Parallel.processor_count)

        t = Time.now

        training_data.each_slice(batch_size).with_index do |batch, i|
          deltas_errors = in_parallel(processes, batch) do |(expecting, input)|
            output, hidden_state = network.forward(input, Hash.new)
            target = network.prep_output_target(expecting)
            final_output = network.final_output(output)
            errors = cost_function.derivative(target, final_output)
            new_deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
            # The learning rate is applied before the per-example weight
            # deltas are computed.
            new_deltas = network.weight_deltas(input, output, new_deltas * learning_rate)

            [ new_deltas, cost_function.call(target, final_output) ]
          end

          # Split the [deltas, cost] pairs into two parallel arrays.
          deltas, total_errors = deltas_errors.transpose
          network.adjust_weights!(accumulate_deltas(deltas))

          if block
            block.call(BatchStats.new(self, i, batch_size, Time.now - t, CooCoo::Sequence[total_errors].sum))
          end

          # Restart the timer so each batch reports its own duration.
          t = Time.now
        end
      end

      protected

      # Maps +block+ over +args+ with Parallel, using threads when CUDA is
      # available and processes otherwise.
      def in_parallel(processes, *args, &block)
        opts = if CUDA.available?
                 # CUDA can't fork so keep it in a single Ruby
                 { in_threads: processes }
               else
                 { in_processes: processes }
               end
        Parallel.map(*args, opts, &block)
      end

      # Averages the per-example weight deltas into one set of per-layer
      # deltas.
      #
      # Every example, including the first, contributes with weight
      # +1 / deltas.size+. The result is built in a new Array so the first
      # example's deltas are not modified.
      #
      # @param deltas [Array] one entry per example, each an enumerable of
      #   per-layer deltas that support +*+ with a Float and +++.
      # @return [Array] the averaged per-layer deltas. A single example is
      #   returned unchanged.
      # @raise [ArgumentError] if +deltas+ is empty.
      def accumulate_deltas(deltas)
        raise ArgumentError, 'no deltas to accumulate' if deltas.empty?
        # Nothing to average; also avoids requiring +*+ on the lone entry.
        return deltas[0] if deltas.size == 1

        weight = 1.0 / deltas.size.to_f

        acc = deltas[0].collect { |layer| layer * weight }
        deltas.drop(1).each do |step|
          step.each_with_index do |layer, i|
            acc[i] += layer * weight
          end
        end

        acc
      end
    end
  end
end