CooCoo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,94 @@
1
require 'pathname'

# Nested form (rather than `module CooCoo::Cifar`) so this file also loads
# when the CooCoo namespace has not been defined yet.
module CooCoo
  # Readers for the binary CIFAR-10 batch files kept in
  # cifar-10-batches-bin/ next to this file.
  module Cifar
    ROOT = Pathname.new(__FILE__).dirname
    BINARY_BATCHES = {
      labels: ROOT.join("cifar-10-batches-bin", "batches.meta.txt"),
      batches: 5.times.collect { |i| ROOT.join("cifar-10-batches-bin", "data_batch_#{i + 1}.bin") },
      test_batch: ROOT.join("cifar-10-batches-bin", "test_batch.bin")
    }

    # Maps numeric labels to names, one name per line of the label file.
    class LabelSet
      # @param path [String, Pathname] label file (defaults to batches.meta.txt)
      def initialize(path = BINARY_BATCHES[:labels])
        load!(path)
      end

      # (Re)reads the label names from +path+.
      # @return [Array<String>] the names in file order
      def load!(path)
        @labels = File.read(path).split("\n")
      end

      # @param number [Integer] label byte from a batch record
      # @return [String, nil] the label's name, nil when out of range
      def [](number)
        @labels[number]
      end
    end

    # Streams records out of one or more binary batch files.
    class Batch
      WIDTH = 32
      HEIGHT = 32
      NUM_PIXELS = WIDTH * HEIGHT
      NUM_CHANNELS = 3
      # One label byte followed by the red, green and blue planes.
      RECORD_SIZE = 1 + NUM_PIXELS * NUM_CHANNELS

      attr_reader :paths

      # @param paths [Array<String, Pathname>] batch files; the five training
      #   batches when none are given
      def initialize(*paths)
        paths = BINARY_BATCHES[:batches] if paths.empty?
        @paths = []
        paths.each { |p| add(p) }
      end

      # Appends another batch file.
      # @return [self]
      def add(path)
        @paths << path
        self
      end

      # Yields +file_index, record_index, label, red, green, blue+ for every
      # record of every file; each channel is an Array of NUM_PIXELS bytes.
      # Returns an Enumerator when no block is given.
      def each(&block)
        return to_enum(__method__) unless block_given?

        @paths.each_with_index do |path, i|
          enumerate_file(path, i, &block)
        end
      end

      # Reads every complete record of the file at +path+, tagging each with +index+.
      def enumerate_file(path, index, &block)
        $stderr.puts("Loading #{path}")
        File.open(path, 'rb') do |f|
          i = 0
          loop do
            data = f.read(RECORD_SIZE)
            # A truncated trailing record would produce short/nil channel
            # slices, so stop at the first incomplete read.
            break if data.nil? || data.bytesize < RECORD_SIZE

            data = data.unpack('C*')
            label = data[0]
            red = data[1, NUM_PIXELS]
            green = data[1 + NUM_PIXELS, NUM_PIXELS]
            blue = data[1 + NUM_PIXELS * 2, NUM_PIXELS]
            block.call(index, i, label, red, green, blue)
            i += 1
          end
        end
      end
    end
  end
end
72
+
73
if __FILE__ == $0
  require 'chunky_png'

  # Usage: cifar.rb [FIRST [COUNT]]
  # Writes COUNT training images (default 32), after skipping FIRST (default 0),
  # as PNG files named "<batch>-<index>-<label>.png".
  width = CooCoo::Cifar::Batch::WIDTH
  height = CooCoo::Cifar::Batch::HEIGHT
  c = CooCoo::Cifar::Batch.new
  labels = CooCoo::Cifar::LabelSet.new
  first, count = ARGV.collect(&:to_i)
  first ||= 0
  count ||= 32
  puts("Showing #{count}, skipping #{first}")
  # Lazy so reading stops once COUNT records are produced; an eager drop
  # would load every batch file first.
  c.each.lazy.drop(first).first(count).each do |batch, i, label, red, green, blue|
    file = "#{batch}-#{i}-#{labels[label]}.png"
    puts(file)
    png = ChunkyPNG::Image.new(width, height)
    height.times do |y|
      width.times do |x|
        idx = y * width + x
        png[x, y] = ChunkyPNG::Color.rgb(red[idx], green[idx], blue[idx])
      end
    end
    png.save(file, :interlace => true)
  end
end
@@ -0,0 +1,201 @@
1
+ require 'coo-coo'
2
+ require 'chunky_png'
3
+
4
# Holds a set of PNG images decoded into normalized RGB vectors.
class ImageStream
  attr_reader :images

  # Decodes every given PNG path up front into
  # [path, width, height, pixels] entries.
  def initialize(*images)
    @images = images.map { |path| load_image(path) }
  end

  # Flattens one PNG row by row into interleaved R, G, B values, each
  # divided by 256.0.
  # @return [Array] [path, width, height, pixels]
  def load_image(path)
    image = ChunkyPNG::Image.from_file(path)
    row_length = image.width * 3
    rgb = CooCoo::Vector.new(image.width * image.height * 3)

    image.pixels.each_slice(image.width).each_with_index do |row, row_index|
      values = row.flat_map do |color|
        [ChunkyPNG::Color.r(color), ChunkyPNG::Color.g(color), ChunkyPNG::Color.b(color)]
      end
      rgb[row_index * row_length, row_length] = values
    end

    [path, image.width, image.height, rgb / 256.0]
  end

  # @return [Integer] number of loaded images
  def size
    images.size
  end

  # Images are always decoded to R, G and B.
  def channels
    3
  end

  # Yields +path, width, height, pixels+ for each image; returns an
  # Enumerator without a block.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @images.each { |entry| yield(*entry) }
  end
end
42
+
43
# Cuts randomly placed, fixed-size slices out of the images of a stream.
class ImageSlicer
  attr_reader :slice_width
  attr_reader :slice_height

  # @param num_slices [Integer] passes over the stream; each pass picks one
  #   random anchor point per image
  # @param slice_width [Integer] slice width in pixels
  # @param slice_height [Integer] slice height in pixels
  # @param image_stream [#each, #size, #channels] yields path, width, height, pixels
  # @param chitters [Integer] slices cut per anchor; when greater than 1 each
  #   is jittered by up to half a slice around the anchor. With the default
  #   of 0 nothing is yielded.
  def initialize(num_slices, slice_width, slice_height, image_stream, chitters = 0)
    @num_slices = num_slices
    @slice_width = slice_width
    @slice_height = slice_height
    @chitters = chitters
    @stream = image_stream
  end

  # @return [Integer] number of slices #each yields
  def size
    @num_slices * @stream.size * @chitters
  end

  # @return [Integer] channels per pixel, as reported by the stream
  def channels
    @stream.channels
  end

  # Yields +path, slice, x, y+ for every slice, where (x, y) is the slice's
  # top-left corner in pixels. Returns an Enumerator without a block.
  def each(&block)
    return to_enum(__method__) unless block_given?

    half_w = @slice_width / 2
    half_h = @slice_height / 2

    @num_slices.times do
      @stream.each do |path, width, height, pixels|
        xr = rand(width)
        yr = rand(height)
        @chitters.times do
          x = xr
          x += rand(@slice_width) - half_w if @chitters > 1
          x = clamp_origin(x, width, @slice_width)
          y = yr
          y += rand(@slice_height) - half_h if @chitters > 1
          y = clamp_origin(y, height, @slice_height)

          # The pixel vector stores interleaved channel values, so every
          # horizontal quantity (row length, x offset, slice width) must be in
          # pixels * channels; an unscaled x would misalign the channels.
          slice = pixels.slice_2d(width * channels, height, x * channels, y, @slice_width * channels, @slice_height)
          yield(path, slice, x, y)
        end
      end
    end
  end

  private

  # Keeps a slice origin inside the image: never past the far edge and never
  # negative (jitter near the near edge can push it below zero).
  def clamp_origin(pos, extent, slice_extent)
    [[pos, extent - slice_extent].min, 0].max
  end
end
87
+
88
# Pairs slices from an ImageSlicer into [target, input] training examples.
class TrainingSet
  attr_reader :slicer, :batch_size

  # @param slicer [ImageSlicer] source of path, slice, x, y samples
  # @param batch_size [Integer] samples shuffled together when forming pairs
  def initialize(slicer, batch_size)
    @slicer = slicer
    @batch_size = batch_size
  end

  # @return [Integer] number of samples the slicer produces
  def size
    slicer.size
  end

  # Target layout: [same image?, width closeness, height closeness].
  def output_size
    3
  end

  # Two slices of channels * slice_width * slice_height values each.
  def input_size
    2 * @slicer.channels * @slicer.slice_width * @slicer.slice_height
  end

  # Row-interleaves two slices: a's row 0, b's row 0, a's row 1, b's row 1, ...
  def interleave(a, b)
    combined = CooCoo::Vector.zeros(input_size)
    row_length = @slicer.slice_width * @slicer.channels

    (0...@slicer.slice_height).each do |row_index|
      offset = row_index * row_length
      combined[offset * 2, row_length] = a[offset, row_length]
      combined[offset * 2 + row_length, row_length] = b[offset, row_length]
    end

    combined
  end

  # Yields one [target, input] pair per sample, pairing two independent
  # shuffles of each batch. Returns an Enumerator without a block.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @slicer.each.each_slice(@batch_size) do |samples|
      firsts = samples.shuffle
      seconds = samples.shuffle
      firsts.zip(seconds).each do |left, right|
        yield([target_for(left, right), interleave(left[1], right[1])])
      end
    end
  end

  # Expected output for two [path, slice, x, y] samples: all zeros for
  # different images; otherwise 1.0 in slot 0 and, in slots 1 and 2,
  # slice size / distance - 1 limited to the range 0..1.
  def target_for(a, b)
    target = CooCoo::Vector.zeros(output_size)
    return target unless a[0] == b[0]

    distance = CooCoo::Vector[[b[2] - a[2], b[3] - a[3]]].magnitude
    target[0] = 1.0
    target[1] = unit_closeness(@slicer.slice_width.to_f / distance - 1.0)
    target[2] = unit_closeness(@slicer.slice_height.to_f / distance - 1.0)
    target
  end

  private

  # Caps values above 1.0 at 1.0 and turns non-positive values into 0.0.
  def unit_closeness(value)
    return 1.0 if value > 1.0
    return 0.0 if value <= 0.0
    value
  end
end
147
+
148
if $0 != __FILE__
  # Only runs when this file is loaded by another program. The final
  # expression evaluates to [method(:training_set), @opts], presumably
  # consumed by the loading script.
  require 'pathname'
  require 'ostruct'

  # Defaults; each field can be overridden by the --data-* options below.
  @options = OpenStruct.new(slice_width: 32,
                            slice_height: 32,
                            num_slices: 1000,
                            cycles: 100,
                            images: [],
                            chitters: 4)

  @opts = CooCoo::OptionParser.new do |o|
    o.banner = "Image Similarity Data Generator options"

    # Every --data-path glob adds its matches to the image list.
    o.on('--data-path PATH') { |pattern| @options.images += Dir.glob(pattern) }
    o.on('--data-slice-width SIZE') { |value| @options.slice_width = value.to_i }
    o.on('--data-slice-height SIZE') { |value| @options.slice_height = value.to_i }
    o.on('--data-slices NUM') { |value| @options.num_slices = value.to_i }
    o.on('--data-cycles NUM') { |value| @options.cycles = value.to_i }
    o.on('--data-chitters NUM') { |value| @options.chitters = value.to_i }
  end

  # Builds a TrainingSet from the current options. Naming caveat: the
  # slicer's pass count comes from --data-cycles, while --data-slices sets
  # the TrainingSet's batch size.
  def training_set
    images = ImageStream.new(*@options.images)
    slicer = ImageSlicer.new(@options.cycles, @options.slice_width, @options.slice_height,
                             images, @options.chitters)
    TrainingSet.new(slicer, @options.num_slices.to_i)
  end

  [method(:training_set), @opts]
end
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env ruby
2
+
3
+ require 'coo-coo'
4
+
5
# Target functions for the toy regression below. Each maps an input vector to
# a one-element CooCoo::Vector. Lambdas (unlike Proc.new) reject a wrong
# number of arguments.

# Mean of the inputs.
average = lambda do |m|
  e = m.each
  CooCoo::Vector[[e.sum / e.count]]
end

# Inputs scaled by 255 and truncated to integers, XORed together, then
# divided by 256.
xor = lambda do |m|
  CooCoo::Vector[[m.to_a.flatten.inject(0) { |acc, n| acc ^ (255.0 * n).to_i } / 256.0]]
end

# Largest input.
max = lambda do |m|
  CooCoo::Vector[[m.each.max]]
end
17
+
18
# Builds +n+ [target, input] training pairs. Each input is a random
# 3-element CooCoo::Vector; its target is the block's result for it.
#
# @param n [Integer] number of pairs
# @yieldparam m [CooCoo::Vector] the random input
# @return [Array<Array>] the [target, input] pairs
# @raise [ArgumentError] when no block is given
def data(n, &block)
  raise ArgumentError, "Block not given" unless block_given?

  Array.new(n) do
    m = CooCoo::Vector.rand(3)
    [block.call(m), m]
  end
end
29
+
30
# Runs +input+ through +model+ and prints the prediction next to the expected
# value and the difference (expected minus predicted).
def print_prediction(model, input, expecting)
  prediction = model.forward(input)
  error = expecting - prediction
  puts("#{input} -> #{prediction}, expecting #{expecting}, #{error}")
end
34
+
35
Random.srand(123)

f = max
training_data = data(1000, &f)
model = CooCoo::Network.new
model.layer(CooCoo::Layer.new(3, 8))
model.layer(CooCoo::Layer.new(8, 8))
model.layer(CooCoo::Layer.new(8, 1))

puts("Training")
# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
model.train(network: model,
            data: training_data,
            learning_rate: 0.3,
            batch_size: 200)
puts("\tElapsed #{(Process.clock_gettime(Process::CLOCK_MONOTONIC) - started) / 60.0} min")

puts("Predicting:")

print_prediction(model, training_data.first[1], training_data.first[0])
[[0.5, 0.75, 0.25], [0.25, 0.0, 0.0], [1.0, 0.0, 0.0]].each do |values|
  input = CooCoo::Vector[values]
  print_prediction(model, input, f.(input))
end
@@ -0,0 +1,365 @@
1
require 'pathname'
require 'net/http'
require 'stringio'
require 'zlib'
require 'coo-coo/image'
5
+
6
# Readers for the MNIST digit dataset, stored as IDX files next to this
# script and downloaded on demand by DataStream.
module MNist
  PATH = Pathname.new(__FILE__)
  MNIST_URIS = [ "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
               ]
  TRAIN_LABELS_PATH = PATH.dirname.join('train-labels-idx1-ubyte')
  TRAIN_IMAGES_PATH = PATH.dirname.join('train-images-idx3-ubyte')
  TEST_LABELS_PATH = PATH.dirname.join('t10k-labels-idx1-ubyte')
  TEST_IMAGES_PATH = PATH.dirname.join('t10k-images-idx3-ubyte')

  # Magic numbers that open IDX label and image files.
  LABELS_MAGIC = 0x801
  IMAGES_MAGIC = 0x803

  Width = 28
  Height = 28

  # Reads +count+ big-endian 32-bit header words from +io+ and checks that the
  # first one is +magic+.
  #
  # @return [Array<Integer>] the header words, magic number first
  # @raise [RuntimeError] on a wrong magic number or a truncated header
  def self.read_header(io, count, magic, path)
    words = (io.read(4 * count) || "").unpack("N#{count}")
    raise RuntimeError, "Invalid magic number #{words.first} in #{path}" if words.first != magic
    raise RuntimeError, "Truncated header in #{path}" if words.include?(nil)
    words
  end

  # Downloads and decompresses the MNIST files into this script's directory.
  module Fetcher
    # Callers use Fetcher.fetch!, so the methods must also exist on the
    # module object itself (they were plain instance methods before).
    extend self

    # @param url [URI] location of a gzip-compressed file
    # @return [String] the decompressed bytes
    def fetch_gzip_url(url)
      data = Net::HTTP.get(url)
      Zlib::GzipReader.new(StringIO.new(data)).read
    end

    # Fetches every MNIST_URIS entry and saves it without its .gz suffix.
    def fetch!
      MNIST_URIS.each do |uri|
        uri = URI.parse(uri)
        path = PATH.dirname.join(File.basename(uri.path).sub(".gz", ""))
        # Binary mode: the IDX payload must not undergo newline translation.
        File.binwrite(path, fetch_gzip_url(uri))
      end
    end
  end

  # One labelled image, optionally tagged with the transform applied to it.
  class Example
    attr_accessor :label
    attr_accessor :pixels
    attr_accessor :angle
    attr_accessor :offset_x
    attr_accessor :offset_y

    PixelValues = ' -+X#'.freeze

    def initialize(label, pixels, angle = 0, offset_x = 0, offset_y = 0)
      @label = label
      @pixels = pixels
      @angle = angle
      @offset_x = offset_x
      @offset_y = offset_y
    end

    # @return the value at (x, y) in a Width-wide row-major layout, or 0 when
    #   there is no stored value there
    def pixel(x, y)
      @pixels[y * MNist::Width + x] || 0
    end

    # Renders the image as Height lines of Width characters.
    def to_ascii
      MNist::Height.times.map do |y|
        MNist::Width.times.map { |x| char_for_pixel(pixel(x, y)) }.join + "\n"
      end.join
    end

    # Yields every pixel row by row; returns an Enumerator without a block.
    def each_pixel(&block)
      return to_enum(__method__) unless block_given?
      MNist::Height.times do |y|
        MNist::Width.times do |x|
          yield(pixel(x, y))
        end
      end
    end

    private

    def char_for_pixel(p)
      index = (p / 256.0 * PixelValues.length).to_i
      # Values outside 0..255 would index past the palette (nil) or wrap
      # around from its end, so pin the index to the palette bounds.
      PixelValues[[[index, 0].max, PixelValues.length - 1].min]
    end
  end

  # Reads Examples one at a time from an open pair of IDX label/image files.
  class DataStreamer
    def initialize(labels_path, images_path)
      @labels, @size = open_labels(labels_path)
      begin
        @images, @image_size = open_images(images_path)
      rescue StandardError
        # Don't leak the labels file when the images file is unusable.
        @labels.close
        raise
      end
    end

    def close
      @labels.close
      @images.close
    end

    # @return [Integer] number of labels declared in the label file header
    def size
      @size
    end

    # @return [Example, nil] the next example, nil once either file runs out
    def next
      label = next_label
      pixels = next_image
      Example.new(label, pixels) if label && pixels
    end

    private

    def open_labels(path)
      f = File.open(path, "rb")
      _magic, number = MNist.read_header(f, 2, LABELS_MAGIC, path)
      [ f, number ]
    rescue StandardError
      f.close if f
      raise
    end

    def next_label
      l = @labels.read(1)
      l ? l.unpack('C').first : nil
    end

    def open_images(path)
      f = File.open(path, "rb")
      _magic, _num_images, height, width = MNist.read_header(f, 4, IMAGES_MAGIC, path)
      [ f, width * height ]
    rescue StandardError
      f.close if f
      raise
    end

    def next_image
      p = @images.read(@image_size)
      # A short read means a truncated file: treat it as the end.
      return nil if p.nil? || p.bytesize < @image_size
      p.unpack('C*')
    end
  end

  # Re-readable sequence of Examples from a label/image file pair.
  class DataStream
    # Downloads the standard train/test files when either is missing.
    # @raise [ArgumentError] when a file still does not exist
    def initialize(labels_path = TRAIN_LABELS_PATH, images_path = TRAIN_IMAGES_PATH)
      if standard_pair?(labels_path, images_path) &&
         !(File.exist?(labels_path) && File.exist?(images_path))
        Fetcher.fetch!
      end

      raise ArgumentError, "File does not exist: #{labels_path}" unless File.exist?(labels_path)
      raise ArgumentError, "File does not exist: #{images_path}" unless File.exist?(images_path)

      @labels_path = labels_path
      @images_path = images_path

      read_metadata
    end

    attr_reader :size
    attr_reader :width
    attr_reader :height

    # Yields each Example, closing the files afterwards (also on early
    # exit); returns an Enumerator without a block.
    def each(&block)
      return enum_for(__method__) unless block_given?

      # Opened before the begin so a failed open isn't followed by a close on nil.
      streamer = DataStreamer.new(@labels_path, @images_path)
      begin
        while (ex = streamer.next)
          block.call(ex)
        end
      ensure
        streamer.close
      end
    end

    # With no arguments returns an Enumerator over #each, as before; with
    # arguments it defers to Object#to_enum so to_enum(:meth) keeps working.
    def to_enum(*args, &block)
      return each if args.empty? && block.nil?
      super
    end

    # Enumerator yielding +rotations+ rotated copies of every source example.
    class Rotator < Enumerator
      def initialize(data, rotations, rotation_range, random = false)
        @data = data.to_enum
        @rotations = rotations
        @rotation_range = rotation_range
        @random = random

        super() do |y|
          loop do
            example = @data.next
            @rotations.times do |r|
              t = @random ? rand : (r / @rotations.to_f)
              theta = t * @rotation_range - @rotation_range / 2.0
              img = rotate_pixels(example.pixels, theta)
              y << Example.new(example.label, img.to_a.flatten, theta)
            end
          end
        end
      end

      def wrap(enum)
        self.class.new(enum, @rotations, @rotation_range, @random)
      end

      # Skips the first +n+ source examples (before rotation). Lazy, so the
      # rest of the data set is not buffered into an Array.
      def drop(n)
        wrap(@data.lazy.drop(n))
      end

      def rotate_pixels(pixels, theta)
        rot = CooCoo::Image::Rotate.new(MNist::Width / 2, MNist::Height / 2, theta)
        img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
        (img * rot)
      end
    end

    # Enumerator yielding +num_translations+ shifted copies of every source example.
    class Translator < Enumerator
      def initialize(data, num_translations, dx, dy, random = false)
        @data = data.to_enum
        @num_translations = num_translations
        @dx = dx
        @dy = dy
        @random = random

        super() do |yielder|
          loop do
            example = @data.next
            @num_translations.times do |r|
              # NOTE(review): in non-random mode x and y share the same
              # fraction, so offsets lie on a diagonal — confirm intended.
              x = @random ? rand : (r / @num_translations.to_f)
              x = x * @dx - @dx / 2.0
              y = @random ? rand : (r / @num_translations.to_f)
              y = y * @dy - @dy / 2.0
              img = translate_pixels(example.pixels, x, y)
              yielder << Example.new(example.label, img.to_a.flatten, example.angle, x, y)
            end
          end
        end
      end

      def wrap(enum)
        self.class.new(enum, @num_translations, @dx, @dy, @random)
      end

      # Skips the first +n+ source examples (before translation), lazily.
      def drop(n)
        wrap(@data.lazy.drop(n))
      end

      def translate_pixels(pixels, x, y)
        transform = CooCoo::Image::Translate.new(x, y)
        img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
        (img * transform)
      end
    end

    private

    # True for the bundled train or test file pair, the only ones fetchable.
    def standard_pair?(labels_path, images_path)
      (labels_path == TRAIN_LABELS_PATH && images_path == TRAIN_IMAGES_PATH) ||
        (labels_path == TEST_LABELS_PATH && images_path == TEST_IMAGES_PATH)
    end

    def read_metadata
      read_size
      read_dimensions
    end

    def read_dimensions
      File.open(@images_path, "rb") do |f|
        _magic, _num_images, height, width = MNist.read_header(f, 4, IMAGES_MAGIC, @images_path)
        @width = width
        @height = height
      end
    end

    def read_size
      File.open(@labels_path, "rb") do |f|
        _magic, number = MNist.read_header(f, 2, LABELS_MAGIC, @labels_path)
        @size = number
      end
    end
  end

  # Converts a stream of Examples into [one-hot label, pixels / 256.0] pairs.
  class TrainingSet
    def initialize(data_stream = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH))
      @stream = data_stream
    end

    def each(&block)
      return to_enum(__method__) unless block

      # Internal iteration lets the stream's ensure close its files even when
      # the consumer stops early (e.g. first(n)).
      @stream.each do |example|
        one_hot = Array.new(10, 0.0)
        one_hot[example.label] = 1.0
        block.call([ CooCoo::Vector[one_hot], CooCoo::Vector[example.pixels] / 256.0 ])
      end
    end
  end

  # Eagerly loaded collection of Examples.
  class DataSet
    attr_reader :examples

    def initialize
      @examples = Array.new
    end

    def load!(labels_path, images_path)
      @examples = DataStream.new(labels_path, images_path).each.to_a
      self
    end
  end
end
338
+
339
if __FILE__ == $0
  # Prints an example's label followed by its ASCII rendering.
  def print_example(ex)
    puts(ex.label)
    puts(ex.to_ascii)
  end

  data = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH)

  # Keep the first example of each digit and stop once ten labels are seen,
  # instead of loading the whole training set just to group it.
  firsts = {}
  data.each do |example|
    firsts[example.label] ||= example
    break if firsts.size >= 10
  end

  firsts.sort_by(&:first).each_with_index do |(_label, example), i|
    puts(i)
    print_example(example)
  end

  rot = MNist::DataStream::Rotator.new(data.each, 10, Math::PI, false)
  rot.drop(10).first(10).each do |example|
    print_example(example)
  end

  puts("#{data.size} total #{data.width}x#{data.height} images")
end