CooCoo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,67 @@
1
+ require 'coo-coo/consts'
2
+ require 'coo-coo/math'
3
+ require 'coo-coo/debug'
4
+ require 'coo-coo/cuda'
5
+ require 'coo-coo/layer_factory'
6
+ require 'coo-coo/neuron_layer'
7
+ require 'coo-coo/vector_layer'
8
+ require 'coo-coo/linear_layer'
9
+ require 'coo-coo/fully_connected_layer'
10
+
module CooCoo
  # Pick the concrete class behind CooCoo::Layer. Setting the environment
  # variable COOCOO_USE_VECTOR to "0" selects the per-neuron implementation;
  # any other value (including leaving it unset) selects VectorLayer.
  # A CUDA-based condition was once considered here as well:
  #   ENV["COOCOO_USE_CUDA"] != "0" && CooCoo::CUDA.available?
  Layer = if ENV["COOCOO_USE_VECTOR"] == "0"
            CooCoo::NeuronLayer
          else
            CooCoo::VectorLayer
          end

  CooCoo.debug("Defined CooCoo::Layer as #{Layer}")

  # Class-level delegation to LayerFactory is currently disabled; kept for
  # reference:
  #
  #   class << Layer
  #     def find_type(name)
  #       LayerFactory.find_type(name)
  #     end
  #
  #     def from_hash(*args)
  #       LayerFactory.from_hash(*args)
  #     end
  #   end
end
30
+
if __FILE__ == $0
  # Manual self-test: train a 4-input, 2-output layer on four one-hot
  # inputs, printing every intermediate value, then print the final
  # input -> output mapping.
  #   ACTIVATION  activation function name (default "Logistic")
  #   LOOPS       number of passes over the examples (default 100)
  activation = CooCoo::ActivationFunctions.from_name(ENV.fetch("ACTIVATION", "Logistic"))
  layer = CooCoo::Layer.new(4, 2, activation)

  inputs = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0]
  ].map { |row| CooCoo::Vector[row] }

  targets = [
    [1.0, 0.0],
    [0.0, 1.0],
    [0.0, 0.0],
    [0.0, 0.0]
  ].map { |row| CooCoo::Vector[row] }

  examples = inputs.zip(targets)
  passes = ENV.fetch('LOOPS', 100).to_i

  examples.cycle(passes).each_with_index do |(input, target), step|
    output, state = layer.forward(input, {})
    puts("#{step}\t#{input} -> #{target}")
    puts("\toutput: #{output}")

    err = output - target
    # err = err * err * 0.5
    delta, state = layer.backprop(input, output, err, state)
    puts("\tdelta: #{delta}")
    puts("\terror: #{err}")
    puts("\txfer: #{layer.transfer_error(delta)}")

    layer.update_weights!(input, delta * 0.5)
  end

  examples.each do |input, target|
    output, _state = layer.forward(input, {})
    puts("#{input} -> #{output}\t#{target}")
  end
end
@@ -0,0 +1,26 @@
module CooCoo
  # Registry of layer classes keyed by class name. Serialized layers carry
  # that name under :type, which from_hash uses to find the class that
  # rebuilds them.
  module LayerFactory
    class << self
      # @return [Hash{String => Class}, nil] registered classes keyed by
      #   class name; nil until the first register_type call.
      attr_reader :types

      # Registers +klass+ under +klass.name.to_s+.
      # @param klass [Class]
      # @return [Hash{String => Class}] the whole registry after the addition.
      def register_type(klass)
        registry = (@types ||= {})
        registry.store(klass.name.to_s, klass)
        registry
      end

      # @param name [String] a class name as stored by register_type.
      # @return [Class, nil] the registered class, or nil when the name is
      #   unknown or nothing has been registered yet.
      def find_type(name)
        return nil if @types.nil?

        @types[name]
      end

      # Rebuilds a layer by delegating to the registered class's own
      # +from_hash(h, network)+.
      # @param h [Hash] serialized layer; h[:type] selects the class.
      # @param network passed through unchanged.
      # @raise [ArgumentError] when h[:type] names no registered class.
      def from_hash(h, network = nil)
        klass = find_type(h[:type])
        raise ArgumentError.new("invalid layer type #{h[:type].inspect}") unless klass

        klass.from_hash(h, network)
      end
    end
  end
end
@@ -0,0 +1,59 @@
1
+ require 'coo-coo/math'
2
+ require 'coo-coo/activation_functions'
3
+ require 'coo-coo/layer_factory'
4
+
module CooCoo
  # Weightless layer: forward applies the activation function to the input,
  # and backprop scales the incoming errors by the activation's derivative.
  # The weight hooks are no-ops, and errors are transferred unchanged.
  class LinearLayer
    LayerFactory.register_type(self)

    attr_accessor :activation_function
    attr_reader :size

    # @param size [Integer] number of inputs (equal to the number of outputs).
    # @param activation_function defaults to the Identity activation.
    def initialize(size, activation_function = CooCoo::ActivationFunctions::Identity.instance)
      @size = size
      @activation_function = activation_function
    end

    # Input width; identical to #size for this layer.
    def num_inputs
      size
    end

    # @return [Array] [activated input, hidden_state passed through].
    def forward(input, hidden_state)
      activated = @activation_function.call(input)
      [activated, hidden_state]
    end

    # @return [Array] [errors * activation derivative, hidden_state].
    def backprop(input, output, errors, hidden_state)
      slope = @activation_function.derivative(input, output)
      [errors * slope, hidden_state]
    end

    # No weights, so deltas flow back unchanged.
    def transfer_error(deltas)
      deltas
    end

    # Nothing to adjust; returns self.
    def adjust_weights!(deltas)
      self
    end

    # Returns +deltas+ unchanged; +inputs+ is unused.
    def weight_deltas(inputs, deltas)
      deltas
    end

    # Equal when +other+ is the same kind of layer with matching widths and
    # activation function.
    def ==(other)
      return false unless other.kind_of?(self.class)

      num_inputs == other.num_inputs &&
        size == other.size &&
        activation_function == other.activation_function
    end

    # Serializes the layer; the inverse of LinearLayer.from_hash.
    def to_hash(network = nil)
      {
        type: self.class.name,
        size: size,
        f: @activation_function.name
      }
    end

    # Rebuilds a layer from #to_hash output (:size and activation name :f).
    def self.from_hash(h, network = nil)
      activation = ActivationFunctions.from_name(h[:f])
      new(h[:size], activation)
    end
  end
end
@@ -0,0 +1,607 @@
1
+ require 'coo-coo/core_ext'
2
+ require 'coo-coo/math/abstract_vector'
3
+ require 'coo-coo/math/functions'
4
+ require 'coo-coo/math/interpolation'
5
+ require 'coo-coo/cuda'
6
+ require 'coo-coo/cuda/vector'
7
+
8
+ module CooCoo
module Ruby
  # Pure-Ruby vector backend: an Array of numbers with element-wise
  # arithmetic, comparisons and math functions. Binary operators accept
  # either another vector-like value (anything answering #each and #size)
  # or a scalar that is applied to every element.
  class Vector < CooCoo::Math::AbstractVector
    # @param length [Integer] number of elements; must be positive.
    # @param initial_value [Object] fill value used when no block is given.
    # @yieldparam i [Integer] element index; the block's value becomes element i.
    # @raise [ArgumentError] if length <= 0.
    def initialize(length, initial_value = 0.0, &block)
      raise ArgumentError.new("Invalid size for a Vector") if length <= 0

      # Array.new warns when given both a default and a block, so pass only one.
      @elements = if block
                    Array.new(length, &block)
                  else
                    Array.new(length, initial_value)
                  end
    end

    # Builds a vector from +value+, converting each supplied element with #to_f.
    #
    # @param value [#[], #next] an indexable collection (Array, vector, ...)
    #   or an external enumerator.
    # @param max_size [Integer, nil] length of the result; defaults to value.size.
    # @param default_value [Object] used where +value+ supplies nothing: a nil
    #   element (e.g. indexing past the end of an Array when max_size is
    #   larger) or an exhausted enumerator.
    # @return [Vector]
    def self.[](value, max_size = nil, default_value = 0.0)
      if value.respond_to?(:[])
        new(max_size || value.size, default_value) do |i|
          element = value[i]
          # Check for nil before #to_f: nil.to_f is 0.0 and would silently
          # replace default_value when padding.
          element.nil? ? default_value : element.to_f
        end
      else
        new(max_size || value.size, default_value) do
          begin
            element = value.next
            element.nil? ? default_value : element.to_f
          rescue StopIteration
            default_value
          end
        end
      end
    end

    # Lets scalars and plain collections appear on the left of an operator
    # (e.g. 2.0 * vector): a collection goes through Vector[], and a scalar
    # is broadcast into a vector of this size.
    def coerce(other)
      if other.respond_to?(:each)
        [self.class[other], self]
      else
        [self.class.new(size, other), self]
      end
    end

    # @return [Array] the backing array itself (not a copy).
    def to_a
      @elements
    end

    def to_s
      "[#{each.map(&:to_s).join(', ')}]"
    end

    # Reads one element, or with +len+ a sub-vector of up to +len+ elements.
    # Negative indexes count from the end.
    # @raise [RangeError] if i is out of bounds.
    def [](i, len = nil)
      i += size if i < 0
      raise RangeError.new if i >= size || i < 0

      slice = @elements[i, len || 1]
      if len
        self.class[slice]
      elsif slice
        slice[0]
      end
    end

    # vec[i] = x stores one element; vec[i, len] = values replaces +len+
    # elements starting at i (Array#[]= semantics).
    # @raise [RangeError] if i is out of bounds.
    def []=(i, l, v = nil)
      i += size if i < 0
      raise RangeError.new if i >= size || i < 0

      if v
        @elements[i, l] = v
      else
        @elements[i] = l
      end
    end

    # Copies +values+ into this vector in order; a Numeric fills every
    # element, and values beyond this vector's size are ignored.
    # @return [self]
    def set(values)
      values = [values].cycle(size) if values.kind_of?(Numeric)

      values.each_with_index do |v, i|
        break if i >= @elements.size
        @elements[i] = v
      end

      self
    end

    def each(&block)
      @elements.each(&block)
    end

    def each_with_index(&block)
      each.each_with_index(&block)
    end

    # Yields consecutive sub-vectors of +n+ elements; the last one is padded
    # to +n+ with 0.0. Without a block returns an Enumerator.
    def each_slice(n, &block)
      return to_enum(__method__, n) unless block

      @elements.each_slice(n) do |slice|
        block.call(self.class[slice, n])
      end
    end

    # Grows (padding with 0.0, so arithmetic keeps working) or truncates
    # this vector in place. Like before, returns the new backing array, or
    # nil when the size is unchanged.
    def resize(new_size)
      if new_size > size
        @elements = @elements + Array.new(new_size - size, 0.0)
      elsif new_size < size
        @elements = @elements[0, new_size]
      end
    end

    # @return [Vector] a new vector of this vector's elements followed by +other+'s.
    def append(other)
      v = self.class.new(size + other.size)
      each_with_index { |e, i| v[i] = e }
      other.each_with_index { |e, i| v[i + size] = e }
      v
    end

    def sum
      @elements.each.sum
    end

    def magnitude_squared
      (self * self).sum
    end

    # Euclidean length. Uses ::Math.sqrt (as the math functions below do)
    # rather than depending on a Float#sqrt extension.
    def magnitude
      ::Math.sqrt(magnitude_squared)
    end

    def normalize
      self / magnitude
    end

    # Matrix product: this vector is a +height+ x +width+ row-major matrix and
    # +other+ is +oheight+ x +owidth+ (defaulting to this vector's dimensions).
    # @return [Vector] height * owidth elements, row-major.
    # @raise [ArgumentError] on mismatched dimensions or a non-indexable +other+.
    def dot(width, height, other, owidth = nil, oheight = nil)
      unless other.kind_of?(self.class) || other.respond_to?(:[])
        raise ArgumentError.new("argument must be a #{self.class} or enumerable")
      end

      owidth ||= width
      oheight ||= height

      if width * height != size
        raise ArgumentError.new("width & height, #{width}x#{height} don't match our size: #{size}")
      end
      if owidth * oheight != other.size
        raise ArgumentError.new("owidth & oheight, #{owidth}x#{oheight} don't match the argument's size: #{other.size}")
      end
      raise ArgumentError.new("argument's height != this' width") if width != oheight

      products = height.times.flat_map do |row|
        owidth.times.map do |col|
          oheight.times.map { |k| self[row * width + k] * other[k * owidth + col] }.sum
        end
      end
      self.class[products]
    end

    # Element-wise +, -, *, ** and /. +other+ may be a same-sized vector-like
    # or a scalar.
    # @raise [ArgumentError] if a vector-like +other+ has a different size.
    [:+, :-, :*, :**, :/].each do |op|
      define_method(op) { |other| elementwise_op(op, other) }
    end

    def -@
      self * -1.0
    end

    def size
      @elements.size
    end

    def length
      @elements.size
    end

    # Element-wise equality in which NaN equals NaN. Returns false for nil,
    # for different sizes, and for values lacking the needed methods.
    def ==(other)
      return false unless other && size == other.size

      each.zip(other.each).all? { |a, b| a == b || (a.nan? && b.nan?) }
    rescue NoMethodError
      false
    end

    def !=(other)
      !(self == other)
    end

    # Element-wise comparisons returning 1.0 (true) / 0.0 (false) vectors.
    [:<, :<=, :>=, :>].each do |comp|
      define_method(comp) do |other|
        flags = if other.respond_to?(:each)
                  each.zip(other.each).map { |a, b| a.send(comp, b) ? 1.0 : 0.0 }
                else
                  each.map { |a| a.send(comp, other) ? 1.0 : 0.0 }
                end
        self.class[flags]
      end
    end

    [:abs, :floor, :ceil, :round].each do |func|
      define_method(func) do
        self.class[@elements.map { |v| v.send(func) }]
      end
    end

    # Element-wise ::Math functions; values outside a function's domain
    # become NaN instead of raising.
    [:exp,
     :sin, :cos, :tan, :asin, :acos, :atan,
     :sinh, :cosh, :tanh, :asinh, :acosh, :atanh].each do |func|
      define_method(func) do
        values = @elements.map do |v|
          begin
            ::Math.send(func, v)
          rescue ::Math::DomainError
            Float::NAN
          end
        end
        self.class[values]
      end
    end

    private

    # Applies the binary operator +op+ between each element and the matching
    # element of +other+, or +other+ itself when it is a scalar.
    def elementwise_op(op, other)
      values = if other.respond_to?(:each)
                 if size != other.size
                   raise ArgumentError.new("Size mismatch: #{size} != #{other.size}")
                 end
                 each.zip(other.each).map { |mine, theirs| mine.send(op, theirs) }
               else
                 each.map { |mine| mine.send(op, other) }
               end
      self.class[values]
    end
  end
end
309
+
module NMatrix
  # The nmatrix gem is optional here: the default CooCoo::Vector is
  # CUDA::Vector or Ruby::Vector. Loading it softly keeps this file usable
  # without the gem; using NMatrix::Vector then fails with NameError on
  # ::NMatrix.
  begin
    require 'nmatrix'
  rescue LoadError
    # Optional backend unavailable; leave ::NMatrix undefined.
  end

  # Vector backend storing its elements in a 1 x length ::NMatrix.
  class Vector < CooCoo::Math::AbstractVector
    protected

    # The backing 1 x length ::NMatrix, visible to other instances so that
    # binary operators can combine the matrices directly.
    attr_reader :elements

    public

    # @param length [Integer, nil] number of elements; nil builds an empty
    #   shell that Vector[] fills with an existing ::NMatrix.
    # @param initial_value [Numeric] fill value.
    # @yieldparam i [Integer] element index; the block's value becomes element i.
    # @raise [ArgumentError] if length <= 0.
    def initialize(length, initial_value = 0.0, &block)
      return if length.nil?
      raise ArgumentError.new("size must be larger than zero") if length <= 0

      @elements = ::NMatrix.new([1, length], initial_value)
      @elements.size.times { |i| @elements[i] = block.call(i) } if block
    end

    # Builds a vector from an ::NMatrix (wrapped, not copied), an indexable
    # collection, or an external enumerator.
    # @param max_size [Integer, nil] result length; defaults to value.size.
    # @param default_value [Numeric] used for nil or missing elements.
    def self.[](value, max_size = nil, default_value = 0.0)
      if value.kind_of?(::NMatrix)
        v = new(nil)
        v.instance_variable_set('@elements', value)
        v
      elsif value.respond_to?(:[])
        new(max_size || value.size, default_value) do |i|
          value[i] || default_value
        end
      else
        new(max_size || value.size, default_value) do |i|
          begin
            value.next || default_value
          rescue StopIteration
            default_value
          end
        end
      end
    end

    def self.zeros(length)
      self[::NMatrix.zeros([1, length])]
    end

    def self.ones(length)
      self[::NMatrix.ones([1, length])]
    end

    # Lets scalars and plain collections appear on the left of an operator.
    def coerce(other)
      if other.respond_to?(:each)
        return self.class[other], self
      else
        return self.class.new(self.size, other), self
      end
    end

    def to_a
      @elements.to_a
    end

    def to_s
      "[" + to_a.join(", ") + "]"
    end

    # Marshal support: the elements packed as little-endian doubles.
    def _dump(depth)
      @elements.to_a.pack('E*')
    end

    def self._load(args)
      arr = args.unpack('E*')
      self[arr]
    end

    # Reads one element, or with +len+ a sub-vector (clamped to the end).
    # NOTE(review): relies on ::NMatrix (row, column-range) slicing of the
    # 1 x length matrix and on v[0] of the slice returning a scalar — confirm.
    def [](i, len = nil)
      i = size + i if i < 0
      raise RangeError.new if i >= size || i < 0

      if len
        len = (size - i) if (i + len) >= size
        raise ArgumentError.new("length must be > 0") if len <= 0
      end

      v = @elements[0, (i...(i + (len || 1)))]

      if len
        self.class[v]
      else
        v[0]
      end
    end

    # NOTE(review): ::NMatrix reads @elements[i, l] as (row i, column l), not
    # as (start, length) like Ruby::Vector does, so the three-argument form
    # probably does not perform a ranged assignment here — confirm before
    # relying on it.
    def []=(i, l, v = nil)
      i = size + i if i < 0
      raise RangeError.new if i >= size || i < 0

      if v
        @elements[i, l] = v
      else
        @elements[i] = l
      end
    end

    # Copies +values+ in order (a Numeric fills every element). @return [self]
    def set(values)
      values = [values].each.cycle(size) if values.kind_of?(Numeric)

      values.each_with_index do |v, i|
        break if i >= @elements.size
        @elements[i] = v
      end

      self
    end

    # @return [Vector] a new vector of this vector's elements followed by
    #   +other+'s. Built from flat arrays (as #to_a and #_dump already
    #   assume) rather than ::NMatrix#concat, whose axis inference for two
    #   1 x n matrices may not join them along the columns.
    def append(other)
      other = self.class[other] unless other.kind_of?(self.class)
      self.class[to_a + other.to_a]
    end

    def each(&block)
      @elements.each(&block)
    end

    # Yields (element, flat index). Goes through #each so that the index is
    # flat even if ::NMatrix#each_with_index yields (value, row, column) for
    # this 2-D matrix.
    def each_with_index(&block)
      each.each_with_index(&block)
    end

    # Yields sub-vectors of +n+ elements; the last one is padded (nil becomes
    # 0.0 via Vector[]). Without a block returns an Enumerator.
    def each_slice(n, &block)
      if block
        last_slice = (size / n.to_f).ceil.to_i

        @elements.each_slice(n).with_index do |slice, i|
          if i == last_slice - 1
            slice = slice + Array.new(n - slice.size)
          end

          block.call(self.class[slice])
        end
      else
        to_enum(__method__, n)
      end
    end

    def sum
      @elements.each.sum
    end

    def magnitude_squared
      (self * self).sum
    end

    # Euclidean length; ::Math.sqrt avoids depending on a Float#sqrt extension.
    def magnitude
      ::Math.sqrt(magnitude_squared)
    end

    def normalize
      self / magnitude
    end

    # Matrix product of this vector (height x width) with +other+
    # (oheight x owidth). The dimensions of +other+ default to this vector's,
    # matching Ruby::Vector#dot.
    # @raise [ArgumentError] when either size does not match its dimensions.
    def dot(width, height, other, owidth = nil, oheight = nil)
      owidth ||= width
      oheight ||= height

      if other.kind_of?(self.class)
        raise ArgumentError.new("invalid size") if other.size != owidth * oheight
        raise ArgumentError.new("invalid size") if size != width * height

        product = @elements.reshape([height, width]).
                  dot(other.elements.reshape([oheight, owidth]))

        self.class[product.reshape([1, height * owidth])]
      else
        self.dot(width, height, self.class[other], owidth, oheight)
      end
    end

    def +(other)
      if other.kind_of?(self.class)
        self.class[@elements + other.elements]
      elsif other.kind_of?(Numeric)
        self.class[@elements + other]
      else
        self + self.class[other]
      end
    end

    def -@
      self * -1.0
    end

    def -(other)
      if other.kind_of?(self.class)
        self.class[@elements - other.elements]
      elsif other.kind_of?(Numeric)
        self.class[@elements - other]
      else
        self - self.class[other]
      end
    end

    def size
      length
    end

    def length
      @elements.shape[1]
    end

    def *(other)
      if other.kind_of?(self.class)
        self.class[@elements * other.elements]
      elsif other.kind_of?(Numeric)
        self.class[@elements * other]
      else
        self * self.class[other]
      end
    end

    def **(other)
      if other.kind_of?(self.class)
        self.class[@elements ** other.elements]
      elsif other.kind_of?(Numeric)
        self.class[@elements ** other]
      else
        self ** self.class[other]
      end
    end

    def /(other)
      if other.kind_of?(self.class)
        self.class[@elements / other.elements]
      elsif other.kind_of?(Numeric)
        self.class[@elements / other]
      else
        self / self.class[other]
      end
    end

    def ==(other)
      if other.kind_of?(self.class)
        size == other.size && @elements == other.elements
      elsif other != nil
        a, b = coerce(other)
        a == b
      else
        false
      end
    end

    # Element-wise comparisons returning 1.0 (true) / 0.0 (false) vectors.
    [:<, :<=, :>=, :>].each do |comp|
      define_method(comp) do |other|
        if other.kind_of?(self.class)
          self.class[(@elements.send(comp, other.elements)).collect do |v|
            v ? 1.0 : 0.0
          end]
        else
          self.class[(@elements.send(comp, other)).collect do |v|
            v ? 1.0 : 0.0
          end]
        end
      end
    end

    # Element-wise math functions; on a domain error, falls back to
    # Ruby::Vector, which maps such values to NaN.
    [:abs, :exp,
     :floor, :ceil, :round,
     :sin, :cos, :tan, :asin, :acos, :atan,
     :sinh, :cosh, :tanh, :asinh, :acosh, :atanh
    ].each do |func|
      define_method(func) do
        begin
          self.class[@elements.send(func)]
        rescue ::Math::DomainError
          self.class[CooCoo::Ruby::Vector[self.to_a].send(func)]
        end
      end
    end
  end
end
600
+
# Default vector backend: CUDA when available and not disabled with
# COOCOO_USE_CUDA=0; otherwise the pure-Ruby implementation.
# (NMatrix::Vector is an alternative backend, currently not selected.)
Vector = if ENV["COOCOO_USE_CUDA"] != "0" && CooCoo::CUDA.available?
           CUDA::Vector
         else
           # NMatrix::Vector
           Ruby::Vector
         end
607
+ end