CooCoo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (105) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +16 -0
  3. data/CooCoo.gemspec +47 -0
  4. data/Gemfile +4 -0
  5. data/Gemfile.lock +88 -0
  6. data/README.md +123 -0
  7. data/Rakefile +81 -0
  8. data/bin/cuda-dev-info +25 -0
  9. data/bin/cuda-free +28 -0
  10. data/bin/cuda-free-trend +7 -0
  11. data/bin/ffi-gen +267 -0
  12. data/bin/spec_runner_html.sh +42 -0
  13. data/bin/trainer +198 -0
  14. data/bin/trend-cost +13 -0
  15. data/examples/char-rnn.rb +405 -0
  16. data/examples/cifar/cifar.rb +94 -0
  17. data/examples/img-similarity.rb +201 -0
  18. data/examples/math_ops.rb +57 -0
  19. data/examples/mnist.rb +365 -0
  20. data/examples/mnist_classifier.rb +293 -0
  21. data/examples/mnist_dream.rb +214 -0
  22. data/examples/seeds.rb +268 -0
  23. data/examples/seeds_dataset.txt +210 -0
  24. data/examples/t10k-images-idx3-ubyte +0 -0
  25. data/examples/t10k-labels-idx1-ubyte +0 -0
  26. data/examples/train-images-idx3-ubyte +0 -0
  27. data/examples/train-labels-idx1-ubyte +0 -0
  28. data/ext/buffer/Rakefile +50 -0
  29. data/ext/buffer/buffer.pre.cu +727 -0
  30. data/ext/buffer/matrix.pre.cu +49 -0
  31. data/lib/CooCoo.rb +1 -0
  32. data/lib/coo-coo.rb +18 -0
  33. data/lib/coo-coo/activation_functions.rb +344 -0
  34. data/lib/coo-coo/consts.rb +5 -0
  35. data/lib/coo-coo/convolution.rb +298 -0
  36. data/lib/coo-coo/core_ext.rb +75 -0
  37. data/lib/coo-coo/cost_functions.rb +91 -0
  38. data/lib/coo-coo/cuda.rb +116 -0
  39. data/lib/coo-coo/cuda/device_buffer.rb +240 -0
  40. data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
  41. data/lib/coo-coo/cuda/error.rb +51 -0
  42. data/lib/coo-coo/cuda/host_buffer.rb +117 -0
  43. data/lib/coo-coo/cuda/runtime.rb +157 -0
  44. data/lib/coo-coo/cuda/vector.rb +315 -0
  45. data/lib/coo-coo/data_sources.rb +2 -0
  46. data/lib/coo-coo/data_sources/xournal.rb +25 -0
  47. data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
  48. data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
  49. data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
  50. data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
  51. data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
  52. data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
  53. data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
  54. data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
  55. data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
  56. data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
  57. data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
  58. data/lib/coo-coo/debug.rb +8 -0
  59. data/lib/coo-coo/dot.rb +129 -0
  60. data/lib/coo-coo/drawing.rb +4 -0
  61. data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
  62. data/lib/coo-coo/drawing/canvas.rb +68 -0
  63. data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
  64. data/lib/coo-coo/drawing/sixel.rb +214 -0
  65. data/lib/coo-coo/enum.rb +17 -0
  66. data/lib/coo-coo/from_name.rb +58 -0
  67. data/lib/coo-coo/fully_connected_layer.rb +205 -0
  68. data/lib/coo-coo/generation_script.rb +38 -0
  69. data/lib/coo-coo/grapher.rb +140 -0
  70. data/lib/coo-coo/image.rb +286 -0
  71. data/lib/coo-coo/layer.rb +67 -0
  72. data/lib/coo-coo/layer_factory.rb +26 -0
  73. data/lib/coo-coo/linear_layer.rb +59 -0
  74. data/lib/coo-coo/math.rb +607 -0
  75. data/lib/coo-coo/math/abstract_vector.rb +121 -0
  76. data/lib/coo-coo/math/functions.rb +39 -0
  77. data/lib/coo-coo/math/interpolation.rb +7 -0
  78. data/lib/coo-coo/network.rb +264 -0
  79. data/lib/coo-coo/neuron.rb +112 -0
  80. data/lib/coo-coo/neuron_layer.rb +168 -0
  81. data/lib/coo-coo/option_parser.rb +18 -0
  82. data/lib/coo-coo/platform.rb +17 -0
  83. data/lib/coo-coo/progress_bar.rb +11 -0
  84. data/lib/coo-coo/recurrence/backend.rb +99 -0
  85. data/lib/coo-coo/recurrence/frontend.rb +101 -0
  86. data/lib/coo-coo/sequence.rb +187 -0
  87. data/lib/coo-coo/shell.rb +2 -0
  88. data/lib/coo-coo/temporal_network.rb +291 -0
  89. data/lib/coo-coo/trainer.rb +21 -0
  90. data/lib/coo-coo/trainer/base.rb +67 -0
  91. data/lib/coo-coo/trainer/batch.rb +82 -0
  92. data/lib/coo-coo/trainer/batch_stats.rb +27 -0
  93. data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
  94. data/lib/coo-coo/trainer/stochastic.rb +47 -0
  95. data/lib/coo-coo/transformer.rb +272 -0
  96. data/lib/coo-coo/vector_layer.rb +194 -0
  97. data/lib/coo-coo/version.rb +3 -0
  98. data/lib/coo-coo/weight_deltas.rb +23 -0
  99. data/prototypes/convolution.rb +116 -0
  100. data/prototypes/linear_drop.rb +51 -0
  101. data/prototypes/recurrent_layers.rb +79 -0
  102. data/www/images/screamer.png +0 -0
  103. data/www/images/screamer.xcf +0 -0
  104. data/www/index.html +82 -0
  105. metadata +373 -0
@@ -0,0 +1,94 @@
1
+ module CooCoo::Cifar
2
+ ROOT = Pathname.new(__FILE__).dirname
3
+ BINARY_BATCHES = {
4
+ labels: ROOT.join("cifar-10-batches-bin", "batches.meta.txt"),
5
+ batches: 5.times.collect { |i| ROOT.join("cifar-10-batches-bin", "data_batch_#{i + 1}.bin") },
6
+ test_batch: ROOT.join("cifar-10-batches-bin", "test_batch.bin")
7
+ }
8
+
9
# Lookup table from a numeric CIFAR class id to its human readable name.
#
# The label file is read as one name per line; the zero-based line index is
# the class id.
class LabelSet
  # @param path [String, Pathname] label file to read; when omitted the
  #   BINARY_BATCHES[:labels] entry is used.
  def initialize(path = BINARY_BATCHES[:labels])
    load!(path)
  end

  # Replaces the loaded labels with the lines found in +path+.
  #
  # Lines are split on LF or CRLF so a label file saved with Windows line
  # endings does not leave a trailing "\r" on every name. Trailing blank
  # lines are dropped by String#split.
  #
  # @param path [String, Pathname]
  # @return [Array<String>] the loaded labels
  def load!(path)
    @labels = File.read(path).split(/\r?\n/)
  end

  # @param number [Integer] class id
  # @return [String, nil] the label, or nil when the id is out of range
  def [](number)
    @labels[number]
  end
end
22
+
23
# Streams the images stored in one or more CIFAR binary batch files.
#
# Every record is RECORD_SIZE bytes: one label byte followed by the red,
# green and blue planes, each NUM_PIXELS bytes long.
class Batch
  WIDTH = 32
  HEIGHT = 32
  NUM_PIXELS = WIDTH * HEIGHT
  NUM_CHANNELS = 3
  # Size in bytes of one complete record (label + colour planes).
  RECORD_SIZE = 1 + NUM_PIXELS * NUM_CHANNELS

  attr_reader :paths

  # @param paths [Array<String, Pathname>] batch files to read; when none
  #   are given BINARY_BATCHES[:batches] is used.
  def initialize(*paths)
    paths = BINARY_BATCHES[:batches] if paths.empty?
    @paths = []
    paths.each { |path| add(path) }
  end

  # Appends another batch file to the list.
  #
  # @return [Batch] self, so calls can be chained
  def add(path)
    @paths << path
    self
  end

  # Yields every record of every batch file in order.
  #
  # @yieldparam index [Integer] position of the file in #paths
  # @yieldparam i [Integer] position of the record inside that file
  # @yieldparam label [Integer] class id byte
  # @yieldparam red [Array<Integer>] NUM_PIXELS byte values
  # @yieldparam green [Array<Integer>] NUM_PIXELS byte values
  # @yieldparam blue [Array<Integer>] NUM_PIXELS byte values
  # @return [Enumerator] when no block is given
  def each(&block)
    return to_enum(__method__) unless block_given?

    @paths.each_with_index do |path, i|
      enumerate_file(path, i, &block)
    end
  end

  # Reads +path+ record by record and hands each one to +block+.
  #
  # Reading stops at end of file or at the first incomplete record, so a
  # truncated file never produces short or nil colour planes.
  #
  # @param path [String, Pathname] batch file to read
  # @param index [Integer] value passed through as the first block argument
  def enumerate_file(path, index, &block)
    $stderr.puts("Loading #{path}")
    File.open(path, 'rb') do |f|
      i = 0
      while (data = f.read(RECORD_SIZE)) && data.bytesize == RECORD_SIZE
        bytes = data.unpack('C*')
        label = bytes[0]
        red = bytes[1, NUM_PIXELS]
        green = bytes[1 + NUM_PIXELS, NUM_PIXELS]
        blue = bytes[1 + NUM_PIXELS * 2, NUM_PIXELS]
        block.call(index, i, label, red, green, blue)
        i += 1
      end
    end
  end
end
71
+ end
72
+
73
# Demo: when run directly, writes a range of CIFAR images out as PNG files.
#
# Usage: cifar.rb [FIRST [COUNT]]
#   FIRST  number of images to skip (default 0)
#   COUNT  number of images to write (default 32)
if __FILE__ == $0
  require 'chunky_png'

  batches = CooCoo::Cifar::Batch.new
  labels = CooCoo::Cifar::LabelSet.new
  width = CooCoo::Cifar::Batch::WIDTH
  height = CooCoo::Cifar::Batch::HEIGHT

  first, count = ARGV.collect(&:to_i)
  first ||= 0
  count ||= 32
  puts("Showing #{count}, skipping #{first}")

  # Enumerate lazily: an eager #drop would read every record of every batch
  # file into memory before the first image is written.
  batches.each.lazy.drop(first).first(count).each do |batch, i, label, red, green, blue|
    file = "#{batch}-#{i}-#{labels[label]}.png"
    puts(file)
    png = ChunkyPNG::Image.new(width, height)
    height.times do |y|
      width.times do |x|
        # The colour planes are stored row by row.
        p = y * width + x
        png[x, y] = ChunkyPNG::Color.rgb(red[p], green[p], blue[p])
      end
    end
    png.save(file, :interlace => true)
  end
end
@@ -0,0 +1,201 @@
1
+ require 'coo-coo'
2
+ require 'chunky_png'
3
+
4
# Loads a list of PNG files when constructed and exposes each one as a flat
# vector of red, green and blue values divided by 256.
class ImageStream
  attr_reader :images

  # @param images [Array<String>] paths of the PNG files to load
  def initialize(*images)
    @images = images.map { |path| load_image(path) }
  end

  # Reads one PNG into a vector holding three values per pixel, row by row.
  #
  # @param path [String] PNG file to read
  # @return [Array] +[path, width, height, pixels]+
  def load_image(path)
    png = ChunkyPNG::Image.from_file(path)
    row_length = png.width * 3
    pixels = CooCoo::Vector.new(png.width * png.height * 3)

    png.pixels.each_slice(png.width).with_index do |row, row_index|
      rgb = row.flat_map do |colour|
        [ ChunkyPNG::Color.r(colour),
          ChunkyPNG::Color.g(colour),
          ChunkyPNG::Color.b(colour) ]
      end
      pixels[row_index * row_length, row_length] = rgb
    end

    [ path, png.width, png.height, pixels / 256.0 ]
  end

  # @return [Integer] number of loaded images
  def size
    @images.size
  end

  # @return [Integer] values stored per pixel
  def channels
    3
  end

  # Yields +path, width, height, pixels+ for every loaded image.
  #
  # @return [Enumerator] when no block is given
  def each(&block)
    return to_enum(__method__) unless block_given?

    @images.each { |image| yield(*image) }
  end
end
42
+
43
# Cuts randomly positioned, fixed size windows out of the images supplied by
# an image stream.
#
# For each of +num_slices+ passes over the stream one random anchor is chosen
# per image, and +chitters+ windows are emitted around that anchor. With the
# default of 0 chitters nothing is emitted and #size is 0.
class ImageSlicer
  attr_reader :slice_width
  attr_reader :slice_height

  # @param num_slices [Integer] passes made over the image stream
  # @param slice_width [Integer] window width
  # @param slice_height [Integer] window height
  # @param image_stream [#each, #size, #channels] source of
  #   +path, width, height, pixels+ tuples
  # @param chitters [Integer] windows emitted per anchor; when greater than 1
  #   each window is randomly offset from the anchor
  def initialize(num_slices, slice_width, slice_height, image_stream, chitters = 0)
    @num_slices = num_slices
    @slice_width = slice_width
    @slice_height = slice_height
    @chitters = chitters
    @stream = image_stream
  end

  # @return [Integer] total number of windows #each will yield
  def size
    @num_slices * @stream.size * @chitters
  end

  # @return [Integer] channel count reported by the image stream
  def channels
    @stream.channels
  end

  # Yields +path, slice, x, y+ for every window.
  #
  # The window origin is clamped so it is never negative and, when the image
  # is at least as large as the window, the window never runs past the right
  # or bottom edge.
  #
  # NOTE(review): x is handed to slice_2d unscaled while both widths are
  # multiplied by the channel count — confirm the units slice_2d expects.
  #
  # @return [Enumerator] when no block is given
  def each(&block)
    return to_enum(__method__) unless block_given?

    half_w = @slice_width / 2
    half_h = @slice_height / 2

    @num_slices.times do
      @stream.each do |path, width, height, pixels|
        xr = rand(width)
        yr = rand(height)
        @chitters.times do
          x = xr
          x += rand(@slice_width) - half_w if @chitters > 1
          x = clamp_origin(x, @slice_width, width)
          y = yr
          y += rand(@slice_height) - half_h if @chitters > 1
          y = clamp_origin(y, @slice_height, height)

          slice = pixels.slice_2d(width * channels, height, x, y, @slice_width * channels, @slice_height)
          yield(path, slice, x, y)
        end
      end
    end
  end

  private

  # Limits a window origin to the range 0..(limit - extent); falls back to 0
  # when the window is larger than the image.
  def clamp_origin(origin, extent, limit)
    [ [ origin, limit - extent ].min, 0 ].max
  end
end
87
+
88
# Pairs up image slices and produces +[target, input]+ training examples.
#
# The input holds both slices interleaved row by row. The three element
# target holds a same-image flag followed by two proximity scores.
class TrainingSet
  attr_reader :slicer, :batch_size

  # @param slicer [ImageSlicer] source of +[path, slice, x, y]+ tuples
  # @param batch_size [Integer] slices grouped together before pairing
  def initialize(slicer, batch_size)
    @slicer = slicer
    @batch_size = batch_size
  end

  # @return [Integer] the slicer's size
  def size
    @slicer.size
  end

  # @return [Integer] length of every target vector
  def output_size
    3
  end

  # @return [Integer] length of every input vector (two slices)
  def input_size
    2 * @slicer.channels * @slicer.slice_width * @slicer.slice_height
  end

  # Builds one vector whose rows alternate between a row of +a+ and the
  # matching row of +b+.
  def interleave(a, b)
    row_length = @slicer.slice_width * @slicer.channels
    combined = CooCoo::Vector.zeros(input_size)

    @slicer.slice_height.times do |y|
      offset = y * row_length
      combined[2 * offset, row_length] = a[offset, row_length]
      combined[2 * offset + row_length, row_length] = b[offset, row_length]
    end

    combined
  end

  # Yields +[target, input]+ for every pairing of two independently shuffled
  # copies of each batch.
  #
  # @return [Enumerator] when no block is given
  def each(&block)
    return to_enum(__method__) unless block_given?

    @slicer.each.each_slice(@batch_size) do |batch|
      lefts = batch.shuffle
      rights = batch.shuffle
      lefts.zip(rights) do |a, b|
        yield([target_for(a, b), interleave(a[1], b[1])])
      end
    end
  end

  # Computes the target for a pair of +[path, slice, x, y]+ tuples.
  #
  # Slices from different paths get an all zero target. Otherwise element 0
  # is 1.0 and elements 1 and 2 are the slice width and height divided by
  # the distance between the two origins, minus one, limited to 0.0..1.0.
  def target_for(a, b)
    target = CooCoo::Vector.zeros(output_size)
    return target unless a[0] == b[0]

    target[0] = 1.0
    distance = CooCoo::Vector[[b[2] - a[2], b[3] - a[3]]].magnitude
    target[1] = proximity(@slicer.slice_width, distance)
    target[2] = proximity(@slicer.slice_height, distance)
    target
  end

  private

  # extent / distance - 1, limited to the range 0.0..1.0.
  def proximity(extent, distance)
    score = extent.to_f / distance - 1.0
    if score > 1.0
      1.0
    elsif score <= 0.0
      0.0
    else
      score
    end
  end
end
147
+
148
# Data generator hook: runs only when this file is loaded by another program
# (not when executed directly). It evaluates to a pair holding the
# training-set factory method and the option parser.
if $0 != __FILE__
  require 'pathname'
  require 'ostruct'

  @options = OpenStruct.new(slice_width: 32,
                            slice_height: 32,
                            num_slices: 1000,
                            cycles: 100,
                            images: [],
                            chitters: 4)

  @opts = CooCoo::OptionParser.new do |o|
    o.banner = "Image Similarity Data Generator options"

    # May be given several times; every glob adds to the image list.
    o.on('--data-path PATH') do |path|
      @options.images += Dir.glob(path)
    end

    # Integer valued options, registered in this order.
    { '--data-slice-width SIZE' => :slice_width,
      '--data-slice-height SIZE' => :slice_height,
      '--data-slices NUM' => :num_slices,
      '--data-cycles NUM' => :cycles,
      '--data-chitters NUM' => :chitters }.each do |flag, field|
      o.on(flag) { |n| @options[field] = n.to_i }
    end
  end

  # Builds the training set from the parsed options. The cycle count drives
  # the slicer's number of passes and the slice count is used as batch size.
  def training_set
    slicer = ImageSlicer.new(@options.cycles,
                             @options.slice_width,
                             @options.slice_height,
                             ImageStream.new(*@options.images),
                             @options.chitters)
    TrainingSet.new(slicer, @options.num_slices.to_i)
  end

  [ method(:training_set), @opts ]
end
@@ -0,0 +1,57 @@
1
#!/usr/bin/env ruby

# Trains a small three layer network to approximate a simple function of a
# three element vector and prints a few predictions.

require 'coo-coo'

# Candidate target functions. Each takes the input vector and returns a one
# element vector.
# NOTE(review): Vector is referenced without the CooCoo:: prefix — confirm
# that requiring 'coo-coo' makes it available at the top level.

# Mean of the elements.
average = Proc.new do |m|
  e = m.each
  Vector[[e.sum / e.count]]
end

# Bitwise xor of the elements after scaling each by 255 and truncating.
xor = Proc.new do |m|
  Vector[[m.to_a.flatten.inject(0) { |acc, n| acc ^ (255.0 * n).to_i } / 256.0]]
end

# Largest element.
max = Proc.new do |m|
  Vector[[m.each.max]]
end

# Generates +n+ training pairs of +[expected_output, input]+ from random
# three element inputs.
#
# @raise [ArgumentError] when no block is given
def data(n, &block)
  raise ArgumentError.new("Block not given") unless block_given?

  out = Array.new
  n.times do
    m = Vector.rand(3)
    out << [ block.(m), m ]
  end

  out
end

# Prints the model's output for +input+ next to the expected value and the
# difference between the two.
def print_prediction(model, input, expecting)
  output = model.forward(input)
  puts("#{input} -> #{output}, expecting #{expecting}, #{expecting - output}")
end

# Fixed seed so repeated runs generate the same training data.
Random.srand(123)

f = max
training_data = data(1000, &f)
model = CooCoo::Network.new()
model.layer(CooCoo::Layer.new(3, 8))
model.layer(CooCoo::Layer.new(8, 8))
model.layer(CooCoo::Layer.new(8, 1))

puts("Training")
now = Time.now
# NOTE(review): the model is passed to its own #train as network: — confirm
# this matches the CooCoo::Network#train signature.
model.train(network: model,
            data: training_data,
            learning_rate: 0.3,
            batch_size: 200)
puts("\tElapsed #{(Time.now - now) / 60.0} min")

puts("Predicting:")

print_prediction(model, training_data.first[1], training_data.first[0])
print_prediction(model, Vector[[0.5, 0.75, 0.25]], f.(Vector[[0.5, 0.75, 0.25]]))
print_prediction(model, Vector[[0.25, 0.0, 0.0]], f.(Vector[[0.25, 0.0, 0.0]]))
print_prediction(model, Vector[[1.0, 0.0, 0.0]], f.(Vector[[1.0, 0.0, 0.0]]))
@@ -0,0 +1,365 @@
1
+ require 'pathname'
2
+ require 'net/http'
3
+ require 'zlib'
4
+ require 'coo-coo/image'
5
+
6
+ module MNist
7
+ PATH = Pathname.new(__FILE__)
8
+ MNIST_URIS = [ "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
9
+ "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
10
+ "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
11
+ "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
12
+ ]
13
+ TRAIN_LABELS_PATH = PATH.dirname.join('train-labels-idx1-ubyte')
14
+ TRAIN_IMAGES_PATH = PATH.dirname.join('train-images-idx3-ubyte')
15
+ TEST_LABELS_PATH = PATH.dirname.join('t10k-labels-idx1-ubyte')
16
+ TEST_IMAGES_PATH = PATH.dirname.join('t10k-images-idx3-ubyte')
17
+
18
+ Width = 28
19
+ Height = 28
20
+
21
# Downloads the gzip compressed MNIST files and stores them uncompressed
# next to this source file.
module Fetcher
  # StringIO wraps the downloaded bytes for Zlib::GzipReader.
  require 'stringio'

  # The methods are invoked on the module itself (Fetcher.fetch!), so expose
  # them as module methods while keeping them usable through include.
  extend self

  # Downloads +url+ and returns the decompressed body.
  #
  # @param url [URI] location of a gzip compressed resource
  # @return [String] the uncompressed data
  def fetch_gzip_url(url)
    data = Net::HTTP.get(url)
    reader = Zlib::GzipReader.new(StringIO.new(data))
    begin
      reader.read
    ensure
      reader.close
    end
  end

  # Fetches every entry of MNIST_URIS into the directory of PATH, naming
  # each file after the URL's basename without its ".gz" suffix.
  def fetch!
    MNIST_URIS.each do |uri|
      uri = URI.parse(uri)
      path = PATH.dirname.join(File.basename(uri.path, ".gz"))
      # The payload is binary; text mode would corrupt it on platforms that
      # translate line endings.
      File.binwrite(path, fetch_gzip_url(uri))
    end
  end
end
38
+
39
# One labelled image together with the transformation that produced it.
class Example
  attr_accessor :label
  attr_accessor :pixels
  attr_accessor :angle
  attr_accessor :offset_x
  attr_accessor :offset_y

  # Glyphs used by #to_ascii, ordered from darkest to brightest.
  PixelValues = ' -+X#'

  # @param label [Integer] class of the image
  # @param pixels [Array<Numeric>] row-major pixel values
  # @param angle [Numeric] rotation recorded for this example
  # @param offset_x [Numeric] horizontal translation recorded for this example
  # @param offset_y [Numeric] vertical translation recorded for this example
  def initialize(label, pixels, angle = 0, offset_x = 0, offset_y = 0)
    @label = label
    @pixels = pixels
    @angle = angle
    @offset_x = offset_x
    @offset_y = offset_y
  end

  # @return [Numeric] the pixel at column +x+, row +y+, or 0 when the pixel
  #   array has no value there
  def pixel(x, y)
    @pixels[y * MNist::Width + x] || 0
  end

  # Renders the image as MNist::Height lines of MNist::Width glyphs, each
  # line terminated by a newline.
  #
  # @return [String]
  def to_ascii
    MNist::Height.times.map do |y|
      MNist::Width.times.map { |x| char_for_pixel(pixel(x, y)) }.join << "\n"
    end.join
  end

  # Yields every pixel value row by row.
  #
  # @return [Enumerator] when no block is given
  def each_pixel(&block)
    return to_enum(__method__) unless block_given?

    MNist::Height.times do |y|
      MNist::Width.times do |x|
        yield(pixel(x, y))
      end
    end
  end

  private

  # Maps a pixel value onto a glyph. The glyph index is clamped so values
  # of 256 or more, or below zero, still map to the brightest or darkest
  # glyph instead of indexing outside PixelValues.
  def char_for_pixel(p)
    index = (p / 256.0 * PixelValues.length).to_i
    PixelValues[[ [ index, 0 ].max, PixelValues.length - 1 ].min]
  end
end
85
+
86
# Sequential reader over a labels file and an images file that are read in
# lock step, one example at a time.
#
# Both files start with big-endian 32 bit header fields: the labels file
# with a magic number (0x801) and a count, the images file with a magic
# number (0x803), a count, the height and the width.
class DataStreamer
  # Number of labels announced by the labels file header.
  attr_reader :size

  # Opens both files and validates their headers.
  #
  # @param labels_path [String, Pathname]
  # @param images_path [String, Pathname]
  # @raise [RuntimeError] when a header is truncated or has the wrong magic
  def initialize(labels_path, images_path)
    @labels, @size = open_labels(labels_path)
    begin
      @images, @image_size = open_images(images_path)
    rescue StandardError
      # Do not leave the labels file open when the images file is rejected.
      @labels.close
      raise
    end
  end

  # Closes both underlying files.
  def close
    @labels.close
    @images.close
  end

  # Reads the next label and image.
  #
  # @return [Example, nil] nil once either file has no complete entry left
  def next
    label = next_label
    pixels = next_image
    Example.new(label, pixels) if label && pixels
  end

  private

  # @return [Array] the open labels file and the number of labels
  def open_labels(path)
    f, fields = open_with_header(path, 0x801, 2)
    [ f, fields[0] ]
  end

  # @return [Integer, nil] next label byte, nil at end of file
  def next_label
    l = @labels.read(1)
    l && l.unpack('C').first
  end

  # @return [Array] the open images file and the number of bytes per image
  def open_images(path)
    f, fields = open_with_header(path, 0x803, 4)
    _num_images, height, width = fields
    [ f, width * height * 1 ]
  end

  # @return [Array<Integer>, nil] next image's bytes, nil at end of file or
  #   when only part of an image is left
  def next_image
    p = @images.read(@image_size)
    return nil unless p && p.bytesize == @image_size

    p.unpack('C*')
  end

  # Opens +path+ and reads +field_count+ 32 bit header fields, the first of
  # which must equal +expected_magic+. The file is closed again when the
  # header is rejected.
  #
  # @return [Array] the open file and the header fields after the magic
  def open_with_header(path, expected_magic, field_count)
    f = File.open(path, "rb")
    begin
      header = f.read(4 * field_count).to_s
      magic, *fields = header.unpack('N' * field_count)
      raise RuntimeError.new("Invalid magic number #{magic} in #{path}") if magic != expected_magic
      raise RuntimeError.new("Truncated header in #{path}") if header.bytesize < 4 * field_count

      [ f, fields ]
    rescue StandardError
      f.close
      raise
    end
  end
end
145
+
146
# Re-openable stream over a labels file and an images file.
#
# The headers are read once at construction to learn the number of examples
# and the image dimensions; every call to #each opens the files afresh.
class DataStream
  # Number of examples announced by the labels file.
  attr_reader :size
  # Image width read from the images file header.
  attr_reader :width
  # Image height read from the images file header.
  attr_reader :height

  # When the requested pair is one of the two standard pairs (training or
  # test) and a file is missing, the data is fetched first.
  #
  # @param labels_path [String, Pathname]
  # @param images_path [String, Pathname]
  # @raise [ArgumentError] when a file still does not exist
  # @raise [RuntimeError] when a header has the wrong magic number
  def initialize(labels_path = TRAIN_LABELS_PATH, images_path = TRAIN_IMAGES_PATH)
    standard_pair = (labels_path == TRAIN_LABELS_PATH && images_path == TRAIN_IMAGES_PATH) ||
                    (labels_path == TEST_LABELS_PATH && images_path == TEST_IMAGES_PATH)
    if standard_pair && !(File.exist?(labels_path) && File.exist?(images_path))
      Fetcher.fetch!
    end

    raise ArgumentError.new("File does not exist: #{labels_path}") unless File.exist?(labels_path)
    raise ArgumentError.new("File does not exist: #{images_path}") unless File.exist?(images_path)

    @labels_path = labels_path
    @images_path = images_path

    read_metadata
  end

  # Yields every example in file order.
  #
  # @return [Enumerator] when no block is given
  def each(&block)
    return enum_for(__method__) unless block_given?

    # Opened outside the begin block so a failure to open is reported as is
    # instead of being masked by calling #close on nil.
    streamer = DataStreamer.new(@labels_path, @images_path)
    begin
      while (ex = streamer.next)
        block.call(ex)
      end
    ensure
      streamer.close
    end
  end

  # @return [Enumerator] same as calling #each without a block
  def to_enum
    each
  end

  private

  def read_metadata
    read_size
    read_dimensions
  end

  # Reads the image dimensions from the images file header.
  def read_dimensions
    File.open(@images_path, "rb") do |f|
      magic, _num_images, height, width = f.read(4 * 4).to_s.unpack('NNNN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{@images_path}") if magic != 0x803

      @width = width
      @height = height
    end
  end

  # Reads the example count from the labels file header.
  def read_size
    File.open(@labels_path, "rb") do |f|
      magic, number = f.read(4 * 2).to_s.unpack('NN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{@labels_path}") if magic != 0x801

      @size = number
    end
  end

  public

  # Enumerator that emits +rotations+ rotated copies of every source example.
  # The angles cover +rotation_range+ centred on zero, either evenly stepped
  # or drawn at random.
  class Rotator < Enumerator
    # @param data [#to_enum] source of examples
    # @param rotations [Integer] copies produced per source example
    # @param rotation_range [Numeric] total span of angles
    # @param random [Boolean] draw angles at random instead of stepping
    def initialize(data, rotations, rotation_range, random = false)
      @data = data.to_enum
      @rotations = rotations
      @rotation_range = rotation_range
      @random = random

      super() do |y|
        # Kernel#loop ends quietly once @data.next raises StopIteration.
        loop do
          example = @data.next
          @rotations.times do |r|
            t = @random ? rand : (r / @rotations.to_f)
            theta = t * @rotation_range - @rotation_range / 2.0
            img = rotate_pixels(example.pixels, theta)
            y << Example.new(example.label, img.to_a.flatten, theta)
          end
        end
      end
    end

    # Builds a rotator with the same settings over another source.
    def wrap(enum)
      self.class.new(enum, @rotations, @rotation_range, @random)
    end

    # Skips +n+ source examples (not +n+ rotated copies).
    def drop(n)
      wrap(@data.drop(n))
    end

    # Applies a CooCoo::Image::Rotate built from the image centre and +theta+.
    # NOTE(review): presumably rotates about the centre — confirm against
    # CooCoo::Image::Rotate.
    def rotate_pixels(pixels, theta)
      rot = CooCoo::Image::Rotate.new(MNist::Width / 2, MNist::Height / 2, theta)
      img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
      (img * rot)
    end
  end

  # Enumerator that emits +num_translations+ shifted copies of every source
  # example. Offsets cover +dx+ by +dy+ centred on zero; when stepping, the
  # same fraction is used for both axes.
  class Translator < Enumerator
    # @param data [#to_enum] source of examples
    # @param num_translations [Integer] copies produced per source example
    # @param dx [Numeric] total horizontal span of offsets
    # @param dy [Numeric] total vertical span of offsets
    # @param random [Boolean] draw offsets at random instead of stepping
    def initialize(data, num_translations, dx, dy, random = false)
      @data = data.to_enum
      @num_translations = num_translations
      @dx = dx
      @dy = dy
      @random = random

      super() do |yielder|
        # Kernel#loop ends quietly once @data.next raises StopIteration.
        loop do
          example = @data.next
          @num_translations.times do |r|
            x = @random ? rand : (r / @num_translations.to_f)
            x = x * @dx - @dx / 2.0
            y = @random ? rand : (r / @num_translations.to_f)
            y = y * @dy - @dy / 2.0
            img = translate_pixels(example.pixels, x, y)
            yielder << Example.new(example.label, img.to_a.flatten, example.angle, x, y)
          end
        end
      end
    end

    # Builds a translator with the same settings over another source.
    def wrap(enum)
      self.class.new(enum, @num_translations, @dx, @dy, @random)
    end

    # Skips +n+ source examples (not +n+ translated copies).
    def drop(n)
      wrap(@data.drop(n))
    end

    # Applies a CooCoo::Image::Translate built from +x+ and +y+.
    def translate_pixels(pixels, x, y)
      transform = CooCoo::Image::Translate.new(x, y)
      img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
      (img * transform)
    end
  end
end
300
+
301
# Adapts a stream of examples into +[target, input]+ pairs: a ten element
# one-hot target for the label and the pixel values divided by 256.
class TrainingSet
  # @param data_stream [#each] source of examples responding to +label+ and
  #   +pixels+; defaults to the MNIST training files
  def initialize(data_stream = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH))
    @stream = data_stream
  end

  # Yields one +[target, input]+ pair per example.
  #
  # The stream is iterated with its own #each rather than Enumerator#next,
  # so the stream's #each sees an early exit (break or exception) and a
  # StopIteration raised by the consumer is no longer swallowed.
  #
  # @return [Enumerator] when no block is given
  def each(&block)
    return to_enum(__method__) unless block

    @stream.each do |example|
      a = Array.new(10, 0.0)
      a[example.label] = 1.0
      block.call([ CooCoo::Vector[a],
                   CooCoo::Vector[example.pixels] / 256.0 ])
    end
  end
end
323
+
324
# In-memory collection of examples read from a labels/images file pair.
class DataSet
  attr_reader :examples

  # Starts out with no examples.
  def initialize
    @examples = []
  end

  # Replaces the held examples with everything a DataStream over the two
  # files produces.
  #
  # @return [DataSet] self
  def load!(labels_path, images_path)
    stream = DataStream.new(labels_path, images_path)
    @examples = stream.each.to_a
    self
  end
end
336
+
337
+ end
338
+
339
# Demo: when run directly, prints the first example of each label, then a
# few rotated examples, then a summary of the training data.
if __FILE__ == $0
  # Prints an example's label followed by its ASCII rendering.
  def print_example(ex)
    puts(ex.label)
    puts(ex.to_ascii)
  end

  data = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH)

  # One representative per label, ordered by label.
  representatives = data.each.
    group_by(&:label).
    collect { |label, values| [ label, values.first ] }.
    sort_by(&:first).
    first(10)

  representatives.each_with_index do |(_label, e), i|
    puts(i)
    print_example(e)
    break if i + 1 > 20
  end

  # Ten evenly stepped rotations per example across a range of PI.
  rot = MNist::DataStream::Rotator.new(data.each, 10, Math::PI, false)
  rot.drop(10).first(10).each { |example| print_example(example) }

  puts("#{data.size} total #{data.width}x#{data.height} images")
end