CooCoo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,67 @@
1
+ require 'coo-coo/consts'
2
+ require 'coo-coo/math'
3
+ require 'coo-coo/debug'
4
+ require 'coo-coo/cuda'
5
+ require 'coo-coo/layer_factory'
6
+ require 'coo-coo/neuron_layer'
7
+ require 'coo-coo/vector_layer'
8
+ require 'coo-coo/linear_layer'
9
+ require 'coo-coo/fully_connected_layer'
10
+
11
module CooCoo
  # CooCoo::Layer names the default layer implementation.
  # VectorLayer is used unless the environment sets COOCOO_USE_VECTOR=0,
  # in which case NeuronLayer is used instead. A further condition on CUDA
  # availability was sketched for this choice but is disabled.
  Layer = if ENV["COOCOO_USE_VECTOR"] == "0"
            CooCoo::NeuronLayer
          else
            CooCoo::VectorLayer
          end

  CooCoo.debug("Defined CooCoo::Layer as #{Layer}")

  # Delegating Layer.find_type / Layer.from_hash to LayerFactory is
  # currently disabled:
  #
  #   class << Layer
  #     def find_type(name)
  #       LayerFactory.find_type(name)
  #     end
  #
  #     def from_hash(*args)
  #       LayerFactory.from_hash(*args)
  #     end
  #   end
end
30
+
31
if __FILE__ == $0
  # Smoke test: trains a 4-input / 2-output CooCoo::Layer on four one-hot
  # inputs and prints output, delta, error and transferred error per step.
  # ACTIVATION picks the activation function, LOOPS the number of passes.
  layer = CooCoo::Layer.new(4, 2, CooCoo::ActivationFunctions.from_name(ENV.fetch("ACTIVATION", "Logistic")))

  to_vectors = lambda do |rows|
    rows.map { |row| CooCoo::Vector[row] }
  end

  inputs = to_vectors.call([ [ 1.0, 0.0, 0.0, 0.0 ],
                             [ 0.0, 0.0, 1.0, 0.0 ],
                             [ 0.0, 1.0, 0.0, 0.0 ],
                             [ 0.0, 0.0, 0.0, 1.0 ] ])
  targets = to_vectors.call([ [ 1.0, 0.0 ],
                              [ 0.0, 1.0 ],
                              [ 0.0, 0.0 ],
                              [ 0.0, 0.0 ] ])
  training_pairs = inputs.zip(targets)
  passes = ENV.fetch('LOOPS', 100).to_i

  training_pairs.cycle(passes).each_with_index do |(input, target), step|
    output, hidden_state = layer.forward(input, {})
    puts("#{step}\t#{input} -> #{target}")
    puts("\toutput: #{output}")

    err = output - target
    delta, hidden_state = layer.backprop(input, output, err, hidden_state)
    puts("\tdelta: #{delta}")
    puts("\terror: #{err}")
    puts("\txfer: #{layer.transfer_error(delta)}")

    # Fixed learning rate of 0.5.
    layer.update_weights!(input, delta * 0.5)
  end

  training_pairs.each do |input, target|
    output, _hidden_state = layer.forward(input, {})
    puts("#{input} -> #{output}\t#{target}")
  end
end
@@ -0,0 +1,26 @@
1
module CooCoo
  # Registry of layer classes keyed by class name, used to rebuild layers
  # from their hash form (see +from_hash+).
  module LayerFactory
    class << self
      # @return [Hash{String => Class}, nil] registered classes keyed by
      #   class name; nil until the first {register_type} call.
      attr_reader :types

      # Registers +klass+ under its class name (as a String).
      # @param klass [Class]
      # @return [Hash{String => Class}] the registry
      def register_type(klass)
        @types ||= {}
        @types[klass.name.to_s] = klass
        @types
      end

      # Looks up a registered class by name. Because the registry is keyed
      # by Strings, Symbols are accepted too and converted with #to_s.
      # @param name [String, Symbol, nil]
      # @return [Class, nil] nil when nothing matches or nothing is registered
      def find_type(name)
        return nil if @types.nil? || name.nil?

        @types[name.to_s]
      end

      # Rebuilds a layer by dispatching to the +from_hash+ class method of
      # the class named by h[:type].
      # @param h [Hash] serialized layer; must contain :type
      # @param network passed through to the layer class
      # @raise [ArgumentError] if h[:type] is not a registered type
      def from_hash(h, network = nil)
        klass = find_type(h[:type])
        raise ArgumentError.new("invalid layer type #{h[:type].inspect}") unless klass

        klass.from_hash(h, network)
      end
    end
  end
end
@@ -0,0 +1,59 @@
1
+ require 'coo-coo/math'
2
+ require 'coo-coo/activation_functions'
3
+ require 'coo-coo/layer_factory'
4
+
5
module CooCoo
  # A weightless layer: it only applies its activation function, so the
  # number of outputs equals the number of inputs (+size+).
  class LinearLayer
    LayerFactory.register_type(self)

    attr_accessor :activation_function
    attr_reader :size

    # @param size [Integer] input (and output) width
    # @param activation_function defaults to the Identity activation
    def initialize(size, activation_function = CooCoo::ActivationFunctions::Identity.instance)
      @size = size
      @activation_function = activation_function
    end

    # Same as #size: inputs and outputs have equal width.
    def num_inputs
      size
    end

    # @return [Array] the activated input and the untouched hidden state
    def forward(input, hidden_state)
      activated = @activation_function.call(input)
      [ activated, hidden_state ]
    end

    # @return [Array] errors scaled by the activation derivative, and the
    #   untouched hidden state
    def backprop(input, output, errors, hidden_state)
      scaled = errors * @activation_function.derivative(input, output)
      [ scaled, hidden_state ]
    end

    # There are no weights, so deltas pass through unchanged.
    def transfer_error(deltas)
      deltas
    end

    # No weights to adjust; returns self.
    def adjust_weights!(deltas)
      self
    end

    # No weights: the deltas themselves are returned.
    def weight_deltas(inputs, deltas)
      deltas
    end

    def ==(other)
      return false unless other.kind_of?(self.class)

      num_inputs == other.num_inputs &&
        size == other.size &&
        activation_function == other.activation_function
    end

    # Serializes to a hash understood by LinearLayer.from_hash.
    def to_hash(network = nil)
      hash = { type: self.class.name }
      hash[:size] = size
      hash[:f] = @activation_function.name
      hash
    end

    # Inverse of #to_hash.
    def self.from_hash(h, network = nil)
      width = h[:size]
      new(width, ActivationFunctions.from_name(h[:f]))
    end
  end
end
@@ -0,0 +1,607 @@
1
+ require 'coo-coo/core_ext'
2
+ require 'coo-coo/math/abstract_vector'
3
+ require 'coo-coo/math/functions'
4
+ require 'coo-coo/math/interpolation'
5
+ require 'coo-coo/cuda'
6
+ require 'coo-coo/cuda/vector'
7
+
8
+ module CooCoo
9
+ module Ruby
10
+ class Vector < CooCoo::Math::AbstractVector
11
# Creates a vector of +length+ elements, each +initial_value+, or the
# block's result for each index when a block is given.
# @raise [ArgumentError] when +length+ is not positive
def initialize(length, initial_value = 0.0, &block)
  raise ArgumentError.new("Invalid size for a Vector") if length <= 0

  # Array.new warns when given both a default and a block, so pick one.
  @elements = block_given? ? Array.new(length, &block) : Array.new(length, initial_value)
end
20
+
21
# Builds a vector from +value+.
#
# +value+ is either indexable with #[] (e.g. an Array) or an external
# iterator responding to #next (e.g. an Enumerator). Elements are
# converted with #to_f. Missing elements become +default_value+: nil
# elements, indexes past the end of an Array source when +max_size+ is
# larger, and positions after an iterator raises StopIteration.
#
# @param value [#[], #next] element source
# @param max_size [Integer, nil] vector length; defaults to value.size
# @param default_value [Float] filler for missing elements
# @return [Vector]
def self.[](value, max_size = nil, default_value = 0.0)
  length = max_size || value.size

  if value.respond_to?(:[])
    new(length, default_value) do |i|
      element = value[i]
      element.nil? ? default_value : element.to_f
    end
  else
    new(length, default_value) do
      begin
        element = value.next
        element.nil? ? default_value : element.to_f
      rescue StopIteration
        default_value
      end
    end
  end
end
36
+
37
# Ruby numeric coercion hook (so e.g. +2.0 * vector+ works).
# Enumerables are converted with Vector.[]; anything else is broadcast
# to a vector of this vector's size.
# @return [Array] the converted operand followed by self
def coerce(other)
  promoted = if other.respond_to?(:each)
               self.class[other]
             else
               self.class.new(size, other)
             end
  [promoted, self]
end
44
+
45
# The backing Array itself (not a copy).
def to_a
  @elements
end

# Formats as "[a, b, c]".
def to_s
  "[" + each.map(&:to_s).join(', ') + "]"
end
56
+
57
# Element access. Negative +i+ counts from the end.
# With +len+, returns a new vector of up to +len+ elements starting at +i+.
# Without +len+, returns the single element.
# @raise [RangeError] when +i+ is out of bounds
def [](i, len = nil)
  i += size if i < 0
  raise RangeError.new if i >= size || i < 0

  chunk = @elements[i, len || 1]
  return self.class[chunk] if len

  chunk[0] if chunk
end
69
+
70
# Element assignment. Negative +i+ counts from the end.
#
# * +vec[i] = x+ stores +x+ at index +i+.
# * +vec[i, len] = values+ replaces +len+ elements starting at +i+.
#   Replacements responding to #to_a (e.g. another vector) are converted
#   to an Array first. Array#[]= treats only Arrays (or #to_ary objects)
#   element-wise; anything else would be spliced in as one element.
#
# @raise [RangeError] when +i+ is out of bounds
def []=(i, l, v = nil)
  i = size + i if i < 0
  raise RangeError.new if i >= size || i < 0

  if v
    v = v.to_a if v.respond_to?(:to_a)
    @elements[i, l] = v
  else
    @elements[i] = l
  end
end
80
+
81
# Overwrites elements in place from +values+: an enumerable, or a Numeric
# that is repeated for every element. Values beyond this vector's length
# are ignored.
# @return [self]
def set(values)
  source = values.kind_of?(Numeric) ? [ values ].cycle(size) : values
  limit = @elements.size

  source.each_with_index do |value, index|
    break if index >= limit
    @elements[index] = value
  end

  self
end
91
+
92
# Iterates over the elements; returns an Enumerator without a block.
def each(&block)
  return @elements.each unless block

  @elements.each(&block)
end

# Iterates over [element, index] pairs, going through #each.
def each_with_index(&block)
  elements_enum = each
  elements_enum.each_with_index(&block)
end
99
+
100
# Yields consecutive +n+-element slices as new vectors. Each slice is
# built with Vector.[](slice, n), so the last, shorter slice is padded to
# +n+ elements. Without a block, returns an Enumerator.
def each_slice(n, &block)
  return to_enum(__method__, n) unless block

  @elements.each_slice(n) do |slice|
    block.call(self.class[slice, n])
  end
end
111
+
112
# Grows or shrinks the vector in place to +new_size+ elements.
# New trailing elements are 0.0, the same default #initialize uses, so
# arithmetic and #sum keep working on a grown vector.
#
# @param new_size [Integer] requested length; must be positive
# @return [Array, nil] the new backing array, or nil if the size is unchanged
# @raise [ArgumentError] if +new_size+ is not positive
def resize(new_size)
  raise ArgumentError.new("Invalid size for a Vector") if new_size <= 0

  if new_size > size
    @elements = @elements + Array.new(new_size - size, 0.0)
  elsif new_size < size
    @elements = @elements[0, new_size]
  end
end
119
+
120
# Returns a new vector holding this vector's elements followed by
# +other+'s. Elements are copied as-is, without #to_f conversion.
def append(other)
  offset = size
  combined = self.class.new(offset + other.size)
  each_with_index { |e, i| combined[i] = e }
  other.each_with_index { |e, i| combined[offset + i] = e }
  combined
end
130
+
131
# Sum of all elements.
def sum
  @elements.each.sum
end

# Squared Euclidean norm: the sum of the squared elements.
def magnitude_squared
  (self * self).sum
end

# Euclidean norm. Uses ::Math.sqrt because core Ruby numerics have no
# #sqrt instance method. The leading :: is required inside CooCoo,
# where Math names CooCoo::Math.
def magnitude
  ::Math.sqrt(magnitude_squared)
end

# Vector divided by its magnitude. A zero magnitude is not guarded against.
def normalize
  self / magnitude
end
146
+
147
# Matrix product of this vector, viewed as a +height+ x +width+ row-major
# matrix, and +other+, viewed as an +oheight+ x +owidth+ matrix.
# Returns a vector of height * owidth elements.
# +owidth+ and +oheight+ default to +width+ and +height+.
# @raise [ArgumentError] on a non-indexable argument, mismatched sizes,
#   or when +width+ != +oheight+
def dot(width, height, other, owidth = nil, oheight = nil)
  unless other.kind_of?(self.class) || other.respond_to?(:[])
    raise ArgumentError.new("argument must be a #{self.class} or enumerable")
  end

  owidth ||= width
  oheight ||= height

  if width * height != size
    raise ArgumentError.new("width & height, #{width}x#{height} don't match our size: #{size}")
  end
  if owidth * oheight != other.size
    raise ArgumentError.new("owidth & oheight, #{owidth}x#{oheight} don't match the argument's size: #{other.size}")
  end
  raise ArgumentError.new("argument's height != this' width") if width != oheight

  cells = (0...height).flat_map do |row|
    (0...owidth).map do |col|
      (0...oheight).map { |i| self[row * width + i] * other[i * owidth + col] }.sum
    end
  end

  self.class[cells]
end
174
+
175
# Element-wise addition with a same-sized enumerable, or adds a scalar
# to every element.
# @raise [ArgumentError] on a size mismatch
def +(other)
  elementwise_op(other) { |se, oe| se + oe }
end

# Negation (multiplies every element by -1.0).
def -@
  self * -1.0
end

# Element-wise subtraction (self - other); +other+ may be a scalar.
def -(other)
  elementwise_op(other) { |se, oe| se - oe }
end

def size
  @elements.size
end

def length
  @elements.size
end

# Element-wise multiplication; +other+ may be a scalar.
def *(other)
  elementwise_op(other) { |se, oe| se * oe }
end

# Element-wise power; +other+ may be a scalar exponent.
def **(other)
  elementwise_op(other) { |se, oe| se ** oe }
end

# Element-wise division; +other+ may be a scalar divisor.
def /(other)
  elementwise_op(other) { |se, oe| se / oe }
end

# Shared body of the binary operators. When +other+ responds to #each it
# must have the same size and is combined pairwise. Otherwise it is
# treated as a scalar. The block receives (self_element, other_element)
# in that order. The result goes through Vector.[].
# @raise [ArgumentError] "Size mismatch: a != b" when sizes differ
def elementwise_op(other)
  values = if other.respond_to?(:each)
             raise ArgumentError.new("Size mismatch: #{size} != #{other.size}") if size != other.size
             other.each.zip(each).collect { |oe, se| yield(se, oe) }
           else
             each.collect { |e| yield(e, other) }
           end

  self.class[values]
end
private :elementwise_op
261
+
262
# Element-wise equality with any sized enumerable. Elements that are
# both NaN count as equal. Operands that cannot be compared this way
# (nil, objects without #size/#each, ...) are simply unequal.
def ==(other)
  return false unless other && size == other.size

  each.zip(other.each).all? { |a, b| a == b || (a.nan? && b.nan?) }
rescue NoMethodError
  false
end

# Negation of #==.
def !=(other)
  !(self == other)
end
273
+
274
# Element-wise comparisons returning a vector of 1.0 (true) / 0.0 (false).
# +other+ may be an enumerable (compared pairwise) or a scalar.
%i[< <= >= >].each do |comp|
  define_method(comp) do |other|
    flags = if other.respond_to?(:each)
              each.zip(other.each).collect { |a, b| a.send(comp, b) }
            else
              each.collect { |a| a.send(comp, other) }
            end
    self.class[flags.collect { |flag| flag ? 1.0 : 0.0 }]
  end
end

# Element-wise numeric rounding / absolute value.
%i[abs floor ceil round].each do |func|
  define_method(func) do
    self.class[@elements.map { |v| v.send(func) }]
  end
end

# Element-wise ::Math functions. Out-of-domain elements become NaN
# instead of raising Math::DomainError.
%i[exp sin cos tan asin acos atan sinh cosh tanh asinh acosh atanh].each do |func|
  define_method(func) do
    results = @elements.map do |v|
      begin
        ::Math.send(func, v)
      rescue ::Math::DomainError
        Float::NAN
      end
    end
    self.class[results]
  end
end
307
+ end
308
+ end
309
+
310
+ module NMatrix
311
# nmatrix backs only this optional vector implementation. CooCoo::Vector
# is chosen from CUDA::Vector or Ruby::Vector at the bottom of this file,
# so a missing gem should not stop the library from loading.
# NMatrix::Vector methods raise NameError on ::NMatrix if used without it.
begin
  require 'nmatrix'
rescue LoadError
  nil
end
312
+
313
+ class Vector < CooCoo::Math::AbstractVector
314
# The wrapped ::NMatrix, readable by other vectors of this class only.
attr_reader :elements
protected :elements
318
+
319
# Allocates a 1 x +length+ NMatrix filled with +initial_value+, then
# overwrites each index with the block's result when a block is given.
# A nil +length+ skips allocation; Vector.[] uses that to wrap an
# existing NMatrix.
# @raise [ArgumentError] when +length+ is not positive
def initialize(length, initial_value = 0.0, &block)
  return if length.nil?
  raise ArgumentError.new("size must be larger than zero") if length <= 0

  @elements = ::NMatrix.new([ 1, length ], initial_value)
  return unless block

  @elements.size.times { |i| @elements[i] = block.call(i) }
end
332
+
333
# Builds a vector from +value+: an ::NMatrix (wrapped without copying),
# something indexable with #[], or an iterator responding to #next.
# nil or missing elements become +default_value+.
def self.[](value, max_size = nil, default_value = 0.0)
  if value.kind_of?(::NMatrix)
    wrapped = new(nil)
    wrapped.instance_variable_set('@elements', value)
    return wrapped
  end

  length = max_size || value.size
  if value.respond_to?(:[])
    new(length, default_value) { |i| value[i] || default_value }
  else
    new(length, default_value) do |i|
      begin
        value.next || default_value
      rescue StopIteration
        default_value
      end
    end
  end
end
352
+
353
# A vector of +length+ zeros.
def self.zeros(length)
  matrix = ::NMatrix.zeros([1, length])
  self[matrix]
end

# A vector of +length+ ones.
def self.ones(length)
  matrix = ::NMatrix.ones([1, length])
  self[matrix]
end
360
+
361
# Ruby numeric coercion hook. Enumerables go through Vector.[]; anything
# else is broadcast to a vector of this vector's size.
# @return [Array] the converted operand followed by self
def coerce(other)
  promoted = if other.respond_to?(:each)
               self.class[other]
             else
               self.class.new(size, other)
             end
  [promoted, self]
end
368
+
369
# The elements as returned by NMatrix#to_a.
def to_a
  @elements.to_a
end

# Formats as "[a, b, c]".
def to_s
  "[#{to_a.join(', ')}]"
end

# Marshal support: elements packed as little-endian doubles ('E*').
def _dump(depth)
  @elements.to_a.pack('E*')
end

# Marshal support: inverse of #_dump.
def self._load(args)
  self[args.unpack('E*')]
end
385
+
386
# Element access. Negative +i+ counts from the end. With +len+, returns a
# new vector of up to +len+ elements (clamped at the end); otherwise the
# single element.
# @raise [RangeError] when +i+ is out of bounds
# @raise [ArgumentError] when the clamped length is not positive
def [](i, len = nil)
  i += size if i < 0
  raise RangeError.new if i >= size || i < 0

  if len
    len = size - i if i + len >= size
    raise ArgumentError.new("length must be > 0") if len <= 0
  end

  picked = @elements[0, (i...(i + (len || 1)))]
  len ? self.class[picked] : picked[0]
end
403
+
404
# Element assignment. Negative +i+ counts from the end. The two-argument
# form stores +l+ at +i+; the three-argument form assigns +v+ at @elements[i, l].
# @raise [RangeError] when +i+ is out of bounds
def []=(i, l, v = nil)
  i += size if i < 0
  raise RangeError.new if i >= size || i < 0

  return @elements[i] = l unless v

  @elements[i, l] = v
end
415
+
416
# Overwrites elements in place from +values+ (an enumerable, or a Numeric
# repeated for every element). Extra values are ignored.
# @return [self]
def set(values)
  source = values.kind_of?(Numeric) ? [ values ].each.cycle(size) : values
  limit = @elements.size

  source.each_with_index do |value, index|
    break if index >= limit
    @elements[index] = value
  end

  self
end
426
+
427
# Returns a vector of this vector's elements followed by +other+'s,
# joined with NMatrix#concat. Non-vector arguments go through Vector.[].
# NOTE(review): confirm NMatrix#concat's default axis joins these 1 x N
# matrices column-wise.
def append(other)
  return append(self.class[other]) unless other.kind_of?(self.class)

  self.class[@elements.concat(other.elements)]
end
434
+
435
# Iterates over the wrapped NMatrix; returns its enumerator without a block.
def each(&block)
  return @elements.each unless block

  @elements.each(&block)
end

# Delegates to NMatrix#each_with_index.
def each_with_index(&block)
  return @elements.each_with_index unless block

  @elements.each_with_index(&block)
end
442
+
443
# Yields consecutive +n+-element slices as new vectors. The final slice is
# padded with nils, which Vector.[] replaces with its default value.
# Without a block, returns an Enumerator.
def each_slice(n, &block)
  return to_enum(__method__, n) unless block

  slice_count = (size / n.to_f).ceil.to_i
  @elements.each_slice(n).with_index do |slice, index|
    slice += Array.new(n - slice.size) if index == slice_count - 1
    block.call(self.class[slice])
  end
end
458
+
459
# Sum of all elements.
def sum
  @elements.each.sum
end

# Squared Euclidean norm: the sum of the squared elements.
def magnitude_squared
  (self * self).sum
end

# Euclidean norm. Uses ::Math.sqrt because core Ruby numerics have no
# #sqrt instance method. The leading :: is required inside CooCoo,
# where Math names CooCoo::Math.
def magnitude
  ::Math.sqrt(magnitude_squared)
end

# Vector divided by its magnitude. A zero magnitude is not guarded against.
def normalize
  self / magnitude
end
474
+
475
# Matrix product of this vector, reshaped to +height+ x +width+, and
# +other+, reshaped to +oheight+ x +owidth+. The product is returned
# reshaped to a 1 x (height * owidth) vector.
#
# +owidth+ and +oheight+ are optional, as in Ruby::Vector#dot, and
# default to +width+ and +height+. Non-vector arguments go through
# Vector.[] first.
# @raise [ArgumentError] "invalid size" when a size does not match its dimensions
def dot(width, height, other, owidth = nil, oheight = nil)
  owidth ||= width
  oheight ||= height

  unless other.kind_of?(self.class)
    return self.dot(width, height, self.class[other], owidth, oheight)
  end

  raise ArgumentError.new("invalid size") if other.size != owidth * oheight
  raise ArgumentError.new("invalid size") if size != width * height

  product = @elements.reshape([ height, width ]).
              dot(other.elements.reshape([ oheight, owidth ]))

  self.class[product.reshape([1, height * owidth ])]
end
492
+
493
# Element-wise addition with a vector, a scalar, or anything Vector.[] accepts.
def +(other)
  nmatrix_binary_op(:+, other)
end

# Negation (multiplies by -1.0).
def -@
  self * -1.0
end

# Element-wise subtraction (self - other).
def -(other)
  nmatrix_binary_op(:-, other)
end

def size
  length
end

# Number of columns of the wrapped 1 x N matrix.
def length
  @elements.shape[1]
end

# Element-wise multiplication.
def *(other)
  nmatrix_binary_op(:*, other)
end

# Element-wise power.
def **(other)
  nmatrix_binary_op(:**, other)
end

# Element-wise division.
def /(other)
  nmatrix_binary_op(:/, other)
end

# Shared dispatch for the binary operators. Vectors of this class and
# Numerics are applied directly to the wrapped NMatrix. Anything else is
# converted with Vector.[] and the operator is retried.
def nmatrix_binary_op(op, other)
  case other
  when self.class
    self.class[@elements.send(op, other.elements)]
  when Numeric
    self.class[@elements.send(op, other)]
  else
    send(op, self.class[other])
  end
end
private :nmatrix_binary_op
554
+
555
# Equality. Vectors of this class compare their sizes and NMatrix
# contents. nil is never equal. Anything else is first converted with
# #coerce, so scalars are broadcast before comparing.
def ==(other)
  case other
  when self.class
    size == other.size && @elements == other.elements
  when nil
    false
  else
    lhs, rhs = coerce(other)
    lhs == rhs
  end
end
565
+
566
# Element-wise comparisons delegated to NMatrix, mapped to 1.0 / 0.0.
%i[< <= >= >].each do |comp|
  define_method(comp) do |other|
    rhs = other.kind_of?(self.class) ? other.elements : other
    self.class[(@elements.send(comp, rhs)).collect { |v| v ? 1.0 : 0.0 }]
  end
end

# Element-wise functions delegated to NMatrix. If NMatrix raises
# Math::DomainError, the result is recomputed with Ruby::Vector, which
# maps domain errors of its ::Math-based functions to NaN.
%i[abs exp
   floor ceil round
   sin cos tan asin acos atan
   sinh cosh tanh asinh acosh atanh].each do |func|
  define_method(func) do
    begin
      self.class[@elements.send(func)]
    rescue ::Math::DomainError
      fallback = CooCoo::Ruby::Vector[self.to_a]
      self.class[fallback.send(func)]
    end
  end
end
593
+
594
# `elements` is already defined as a protected reader by the
# `attr_reader :elements` near the top of this class. Re-defining it here
# only triggered a "method redefined" warning under `ruby -w`, so the
# visibility is restated without redefining the method.
protected :elements
598
+ end
599
+ end
600
+
601
# Default vector implementation: the CUDA-backed vector when CUDA is
# available and not disabled with COOCOO_USE_CUDA=0, otherwise the
# pure-Ruby one. NMatrix::Vector is an alternative that is not selected
# here.
Vector = if ENV["COOCOO_USE_CUDA"] == "0" || !CooCoo::CUDA.available?
           Ruby::Vector
         else
           CUDA::Vector
         end
607
+ end