CooCoo 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +16 -0
- data/CooCoo.gemspec +47 -0
- data/Gemfile +4 -0
- data/Gemfile.lock +88 -0
- data/README.md +123 -0
- data/Rakefile +81 -0
- data/bin/cuda-dev-info +25 -0
- data/bin/cuda-free +28 -0
- data/bin/cuda-free-trend +7 -0
- data/bin/ffi-gen +267 -0
- data/bin/spec_runner_html.sh +42 -0
- data/bin/trainer +198 -0
- data/bin/trend-cost +13 -0
- data/examples/char-rnn.rb +405 -0
- data/examples/cifar/cifar.rb +94 -0
- data/examples/img-similarity.rb +201 -0
- data/examples/math_ops.rb +57 -0
- data/examples/mnist.rb +365 -0
- data/examples/mnist_classifier.rb +293 -0
- data/examples/mnist_dream.rb +214 -0
- data/examples/seeds.rb +268 -0
- data/examples/seeds_dataset.txt +210 -0
- data/examples/t10k-images-idx3-ubyte +0 -0
- data/examples/t10k-labels-idx1-ubyte +0 -0
- data/examples/train-images-idx3-ubyte +0 -0
- data/examples/train-labels-idx1-ubyte +0 -0
- data/ext/buffer/Rakefile +50 -0
- data/ext/buffer/buffer.pre.cu +727 -0
- data/ext/buffer/matrix.pre.cu +49 -0
- data/lib/CooCoo.rb +1 -0
- data/lib/coo-coo.rb +18 -0
- data/lib/coo-coo/activation_functions.rb +344 -0
- data/lib/coo-coo/consts.rb +5 -0
- data/lib/coo-coo/convolution.rb +298 -0
- data/lib/coo-coo/core_ext.rb +75 -0
- data/lib/coo-coo/cost_functions.rb +91 -0
- data/lib/coo-coo/cuda.rb +116 -0
- data/lib/coo-coo/cuda/device_buffer.rb +240 -0
- data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
- data/lib/coo-coo/cuda/error.rb +51 -0
- data/lib/coo-coo/cuda/host_buffer.rb +117 -0
- data/lib/coo-coo/cuda/runtime.rb +157 -0
- data/lib/coo-coo/cuda/vector.rb +315 -0
- data/lib/coo-coo/data_sources.rb +2 -0
- data/lib/coo-coo/data_sources/xournal.rb +25 -0
- data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
- data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
- data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
- data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
- data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
- data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
- data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
- data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
- data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
- data/lib/coo-coo/debug.rb +8 -0
- data/lib/coo-coo/dot.rb +129 -0
- data/lib/coo-coo/drawing.rb +4 -0
- data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
- data/lib/coo-coo/drawing/canvas.rb +68 -0
- data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
- data/lib/coo-coo/drawing/sixel.rb +214 -0
- data/lib/coo-coo/enum.rb +17 -0
- data/lib/coo-coo/from_name.rb +58 -0
- data/lib/coo-coo/fully_connected_layer.rb +205 -0
- data/lib/coo-coo/generation_script.rb +38 -0
- data/lib/coo-coo/grapher.rb +140 -0
- data/lib/coo-coo/image.rb +286 -0
- data/lib/coo-coo/layer.rb +67 -0
- data/lib/coo-coo/layer_factory.rb +26 -0
- data/lib/coo-coo/linear_layer.rb +59 -0
- data/lib/coo-coo/math.rb +607 -0
- data/lib/coo-coo/math/abstract_vector.rb +121 -0
- data/lib/coo-coo/math/functions.rb +39 -0
- data/lib/coo-coo/math/interpolation.rb +7 -0
- data/lib/coo-coo/network.rb +264 -0
- data/lib/coo-coo/neuron.rb +112 -0
- data/lib/coo-coo/neuron_layer.rb +168 -0
- data/lib/coo-coo/option_parser.rb +18 -0
- data/lib/coo-coo/platform.rb +17 -0
- data/lib/coo-coo/progress_bar.rb +11 -0
- data/lib/coo-coo/recurrence/backend.rb +99 -0
- data/lib/coo-coo/recurrence/frontend.rb +101 -0
- data/lib/coo-coo/sequence.rb +187 -0
- data/lib/coo-coo/shell.rb +2 -0
- data/lib/coo-coo/temporal_network.rb +291 -0
- data/lib/coo-coo/trainer.rb +21 -0
- data/lib/coo-coo/trainer/base.rb +67 -0
- data/lib/coo-coo/trainer/batch.rb +82 -0
- data/lib/coo-coo/trainer/batch_stats.rb +27 -0
- data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
- data/lib/coo-coo/trainer/stochastic.rb +47 -0
- data/lib/coo-coo/transformer.rb +272 -0
- data/lib/coo-coo/vector_layer.rb +194 -0
- data/lib/coo-coo/version.rb +3 -0
- data/lib/coo-coo/weight_deltas.rb +23 -0
- data/prototypes/convolution.rb +116 -0
- data/prototypes/linear_drop.rb +51 -0
- data/prototypes/recurrent_layers.rb +79 -0
- data/www/images/screamer.png +0 -0
- data/www/images/screamer.xcf +0 -0
- data/www/index.html +82 -0
- metadata +373 -0
@@ -0,0 +1,94 @@
|
|
1
|
+
require 'pathname'

# Readers for the binary release of the CIFAR-10 dataset.
#
# Nested (rather than `module CooCoo::Cifar`) so the file also loads when it
# is run directly and nothing else has defined the CooCoo namespace yet.
module CooCoo
  module Cifar
    ROOT = Pathname.new(__FILE__).dirname
    BINARY_BATCHES = {
      labels: ROOT.join("cifar-10-batches-bin", "batches.meta.txt"),
      batches: 5.times.collect { |i| ROOT.join("cifar-10-batches-bin", "data_batch_#{i + 1}.bin") },
      test_batch: ROOT.join("cifar-10-batches-bin", "test_batch.bin")
    }

    # Maps numeric class labels to their names, one name per line of the
    # labels file (line N names label N).
    class LabelSet
      # @param path [String, Pathname] labels text file
      def initialize(path = BINARY_BATCHES[:labels])
        load!(path)
      end

      # (Re)load the label names from +path+. split("\n") drops trailing empty lines.
      def load!(path)
        @labels = File.read(path).split("\n")
      end

      # @return [String, nil] name of label +number+
      def [](number)
        @labels[number]
      end
    end

    # Streams fixed-size records from one or more CIFAR-10 binary batch files.
    # Each record is one label byte followed by the red, green and blue planes
    # (NUM_PIXELS bytes each).
    class Batch
      WIDTH = 32
      HEIGHT = 32
      NUM_PIXELS = WIDTH * HEIGHT
      NUM_CHANNELS = 3
      # Bytes per record: label byte plus all channel planes.
      RECORD_SIZE = 1 + NUM_PIXELS * NUM_CHANNELS

      attr_reader :paths

      # @param paths [Array<String, Pathname>] batch files; defaults to the
      #   five training batches under ROOT when none are given
      def initialize(*paths)
        paths = BINARY_BATCHES[:batches] if paths.empty?
        @paths = Array.new
        paths.each do |p|
          add(p)
        end
      end

      # Append a batch file. Returns self for chaining.
      def add(path)
        @paths << path
        self
      end

      # Yield (file_index, record_index, label, red, green, blue) for every
      # record of every file, in order. Returns an Enumerator without a block.
      def each(&block)
        return to_enum(__method__) unless block_given?

        @paths.each_with_index do |path, i|
          enumerate_file(path, i, &block)
        end
      end

      # Yield every complete record of +path+. A trailing partial record
      # (truncated file) is reported on stderr and skipped instead of being
      # yielded with nil channel arrays.
      def enumerate_file(path, index, &block)
        $stderr.puts("Loading #{path}")
        File.open(path, 'rb') do |f|
          i = 0
          while (record = f.read(RECORD_SIZE))
            if record.bytesize < RECORD_SIZE
              warn("#{path}: ignoring truncated trailing record (#{record.bytesize} bytes)")
              break
            end

            bytes = record.unpack('C*')
            label = bytes[0]
            red = bytes[1, NUM_PIXELS]
            green = bytes[1 + NUM_PIXELS, NUM_PIXELS]
            blue = bytes[1 + NUM_PIXELS * 2, NUM_PIXELS]
            block.call(index, i, label, red, green, blue)
            i += 1
          end
        end
      end
    end
  end
end
|
72
|
+
|
73
|
+
if __FILE__ == $0
  require 'chunky_png'

  # Usage: cifar.rb [FIRST [COUNT]] -- writes COUNT records, after skipping
  # FIRST, as PNG files named BATCH-INDEX-LABEL.png.
  batch_set = CooCoo::Cifar::Batch.new
  labels = CooCoo::Cifar::LabelSet.new
  first, count = ARGV.collect(&:to_i)
  first ||= 0
  count ||= 32
  puts("Showing #{count}, skipping #{first}")

  width = CooCoo::Cifar::Batch::WIDTH
  height = CooCoo::Cifar::Batch::HEIGHT

  # `lazy` matters: a plain Enumerable#drop builds an Array of every record
  # in every batch file before `first` gets to run.
  batch_set.each.lazy.drop(first).first(count).each do |batch, i, label, red, green, blue|
    file = "#{batch}-#{i}-#{labels[label]}.png"
    puts(file)
    png = ChunkyPNG::Image.new(width, height)
    height.times do |y|
      width.times do |x|
        offset = y * width + x
        png[x, y] = ChunkyPNG::Color.rgb(red[offset], green[offset], blue[offset])
      end
    end
    png.save(file, :interlace => true)
  end
end
|
@@ -0,0 +1,201 @@
|
|
1
|
+
require 'coo-coo'
|
2
|
+
require 'chunky_png'
|
3
|
+
|
4
|
+
# Loads a set of PNG files eagerly and replays them as
# (path, width, height, pixels) tuples.
class ImageStream
  attr_reader :images

  # @param images [Array<String>] PNG file paths; all are decoded immediately
  def initialize(*images)
    @images = images.map { |path| load_image(path) }
  end

  # Decode one PNG into row-major interleaved RGB values scaled by 1/256.
  #
  # @return [Array] [path, width, height, pixel_vector]
  def load_image(path)
    png = ChunkyPNG::Image.from_file(path)
    row_len = png.width * 3
    pixels = CooCoo::Vector.new(png.height * row_len)

    png.pixels.each_slice(png.width).each_with_index do |row, row_index|
      rgb = row.flat_map do |color|
        [ChunkyPNG::Color.r(color), ChunkyPNG::Color.g(color), ChunkyPNG::Color.b(color)]
      end
      pixels[row_index * row_len, row_len] = rgb
    end

    [path, png.width, png.height, pixels / 256.0]
  end

  # Number of loaded images.
  def size
    @images.size
  end

  # Values per pixel (RGB).
  def channels
    3
  end

  # Yield each image's tuple as separate block arguments.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @images.each { |entry| yield(*entry) }
  end
end
|
42
|
+
|
43
|
+
# Cuts randomly positioned slices out of every image of a stream.
class ImageSlicer
  attr_reader :slice_width
  attr_reader :slice_height

  # @param num_slices [Integer] passes made over the image stream
  # @param slice_width [Integer] slice width in pixels
  # @param slice_height [Integer] slice height in pixels
  # @param image_stream stream yielding (path, width, height, pixels) and
  #   answering #size and #channels
  # @param chitters [Integer] slices taken around each random anchor point;
  #   values > 1 jitter each slice by up to half a slice in each direction
  def initialize(num_slices, slice_width, slice_height, image_stream, chitters = 0)
    @num_slices = num_slices
    @slice_width = slice_width
    @slice_height = slice_height
    @chitters = chitters
    @stream = image_stream
  end

  # Total number of slices #each yields.
  def size
    @num_slices * @stream.size * @chitters
  end

  def channels
    @stream.channels
  end

  # Yield (path, slice, x, y) for every slice, with x/y in pixel units.
  # The slice window is always kept inside the image.
  def each(&block)
    return to_enum(__method__) unless block_given?

    half_w = @slice_width / 2
    half_h = @slice_height / 2

    @num_slices.times do
      @stream.each do |path, width, height, pixels|
        xr = rand(width)
        yr = rand(height)
        @chitters.times do
          x = xr
          x += rand(@slice_width) - half_w if @chitters > 1
          x = clamp_offset(x, width, @slice_width)
          y = yr
          y += rand(@slice_height) - half_h if @chitters > 1
          y = clamp_offset(y, height, @slice_height)

          # pixels rows are width * channels values wide, so the column offset
          # must be scaled by channels too, or the slice starts mid-pixel.
          slice = pixels.slice_2d(width * channels, height,
                                  x * channels, y,
                                  @slice_width * channels, @slice_height)
          yield(path, slice, x, y)
        end
      end
    end
  end

  private

  # Clamp +pos+ so a window of +extent+ starting there fits in [0, limit).
  # Jitter can push pos below zero. If the image is smaller than the window,
  # the result is 0.
  def clamp_offset(pos, limit, extent)
    max = limit - extent
    pos = max if pos > max
    pos = 0 if pos < 0
    pos
  end
end
|
87
|
+
|
88
|
+
# Pairs up slices from an ImageSlicer into (target, input) training examples.
# The input is the two slices' rows interleaved. The target is
# [same_image, x_closeness, y_closeness].
class TrainingSet
  attr_reader :slicer, :batch_size

  # @param slicer [ImageSlicer] slice source
  # @param batch_size [Integer] slices shuffled and paired together at a time
  def initialize(slicer, batch_size)
    @slicer = slicer
    @batch_size = batch_size
  end

  def size
    @slicer.size
  end

  def output_size
    3
  end

  # Two slices' worth of values.
  def input_size
    2 * @slicer.channels * @slicer.slice_width * @slicer.slice_height
  end

  # Build one vector whose rows alternate: row r of +a+, then row r of +b+.
  def interleave(a, b)
    row_len = @slicer.slice_width * @slicer.channels

    CooCoo::Vector.zeros(input_size).tap do |out|
      @slicer.slice_height.times do |row|
        src = row * row_len
        dst = src * 2
        out[dst, row_len] = a[src, row_len]
        out[dst + row_len, row_len] = b[src, row_len]
      end
    end
  end

  # Yield [target, input] pairs. Within each batch the slices are shuffled
  # twice and zipped, so every slice appears once on each side.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @slicer.each.each_slice(@batch_size) do |batch|
      lefts = batch.shuffle
      rights = batch.shuffle
      lefts.zip(rights).each do |a, b|
        yield([target_for(a, b), interleave(a[1], b[1])])
      end
    end
  end

  # Target vector for slices +a+ and +b+ ([path, slice, x, y] tuples).
  # All zeros unless both come from the same image.
  def target_for(a, b)
    target = CooCoo::Vector.zeros(output_size)
    return target unless a[0] == b[0]

    target[0] = 1.0
    dist = CooCoo::Vector[[b[2] - a[2], b[3] - a[3]]].magnitude
    target[1] = proximity(@slicer.slice_width, dist)
    target[2] = proximity(@slicer.slice_height, dist)
    target
  end

  private

  # extent / dist - 1, clamped into [0, 1]. A zero distance gives Infinity,
  # which clamps to 1.
  def proximity(extent, dist)
    ratio = extent.to_f / dist - 1.0
    if ratio > 1.0
      1.0
    elsif ratio <= 0.0
      0.0
    else
      ratio
    end
  end
end
|
147
|
+
|
148
|
+
if $0 != __FILE__
  require 'pathname'
  require 'ostruct'

  # Defaults for the data-generator options. Each one can be overridden
  # from the command line through @opts below.
  @options = OpenStruct.new(
    slice_width: 32,
    slice_height: 32,
    num_slices: 1000,
    cycles: 100,
    images: [],
    chitters: 4
  )

  @opts = CooCoo::OptionParser.new do |o|
    o.banner = "Image Similarity Data Generator options"

    o.on('--data-path PATH') do |path|
      @options.images += Dir.glob(path)
    end

    # Integer-valued flags, registered in the same order as before so the
    # help text is unchanged.
    {
      '--data-slice-width SIZE' => :slice_width,
      '--data-slice-height SIZE' => :slice_height,
      '--data-slices NUM' => :num_slices,
      '--data-cycles NUM' => :cycles,
      '--data-chitters NUM' => :chitters
    }.each do |flag, key|
      o.on(flag) do |n|
        @options[key] = n.to_i
      end
    end
  end

  # Build the TrainingSet described by the current @options.
  # Note: the slicer's pass count comes from --data-cycles and the pairing
  # batch size from --data-slices.
  def training_set
    slicer = ImageSlicer.new(@options.cycles,
                             @options.slice_width,
                             @options.slice_height,
                             ImageStream.new(*@options.images),
                             @options.chitters)
    TrainingSet.new(slicer, @options.num_slices.to_i)
  end

  [ method(:training_set), @opts ]
end
|
@@ -0,0 +1,57 @@
|
|
1
|
+
#!/usr/bin/env ruby
# /usr/bin/env is the portable location; /bin/env does not exist on most systems.

require 'coo-coo'
|
4
|
+
|
5
|
+
# Candidate target functions for the toy regression below. Each maps an
# input vector to a one-element Vector.

# Arithmetic mean of the input's elements.
average = proc do |m|
  elements = m.each
  total = elements.sum
  Vector[[total / elements.count]]
end

# Scale each element to 0..255, truncate, XOR the bytes together, and scale
# the result back down by 1/256.
xor = proc do |m|
  bytes = m.to_a.flatten.map { |n| (255.0 * n).to_i }
  Vector[[bytes.reduce(0, :^) / 256.0]]
end

# Largest element of the input.
max = proc do |m|
  largest = m.each.max
  Vector[[largest]]
end
|
17
|
+
|
18
|
+
# Generate +n+ random 3-element inputs, each paired with the block's output
# for it.
#
# @return [Array<Array>] [[expected_output, input], ...]
# @raise [ArgumentError] when no block is given
def data(n, &block)
  raise ArgumentError, "Block not given" unless block_given?

  n.times.map do
    sample = Vector.rand(3)
    [block.call(sample), sample]
  end
end
|
29
|
+
|
30
|
+
# Run +input+ through +model+ and print the prediction, the expected value,
# and their difference (expected - predicted).
def print_prediction(model, input, expecting)
  prediction = model.forward(input)
  error = expecting - prediction
  puts("#{input} -> #{prediction}, expecting #{expecting}, #{error}")
end
|
34
|
+
|
35
|
+
Random.srand(123)

# Target function to learn; swap in `average` or `xor` to try the others.
f = max
training_data = data(1000, &f)

# 3 inputs -> 8 -> 8 -> 1 output.
model = CooCoo::Network.new
[[3, 8], [8, 8], [8, 1]].each do |inputs, outputs|
  model.layer(CooCoo::Layer.new(inputs, outputs))
end

puts("Training")
# Monotonic clock: wall-clock Time.now can jump (NTP, DST) during long runs.
started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
model.train(network: model,
            data: training_data,
            learning_rate: 0.3,
            batch_size: 200)
elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started
puts("\tElapsed #{elapsed / 60.0} min")

puts("Predicting:")

print_prediction(model, training_data.first[1], training_data.first[0])
[[0.5, 0.75, 0.25], [0.25, 0.0, 0.0], [1.0, 0.0, 0.0]].each do |values|
  print_prediction(model, Vector[values], f.(Vector[values]))
end
|
data/examples/mnist.rb
ADDED
@@ -0,0 +1,365 @@
|
|
1
|
+
require 'pathname'
|
2
|
+
require 'net/http'
|
3
|
+
require 'zlib'
|
4
|
+
require 'coo-coo/image'
|
5
|
+
|
6
|
+
# Loader for the MNIST handwritten-digit dataset (IDX binary format), plus
# enumerators that augment examples by rotating or translating them.
module MNist
  PATH = Pathname.new(__FILE__)
  MNIST_URIS = [ "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
               ]
  TRAIN_LABELS_PATH = PATH.dirname.join('train-labels-idx1-ubyte')
  TRAIN_IMAGES_PATH = PATH.dirname.join('train-images-idx3-ubyte')
  TEST_LABELS_PATH = PATH.dirname.join('t10k-labels-idx1-ubyte')
  TEST_IMAGES_PATH = PATH.dirname.join('t10k-images-idx3-ubyte')

  Width = 28
  Height = 28

  # Downloads the gzipped files in MNIST_URIS and stores them, decompressed,
  # next to this file.
  module Fetcher
    # DataStream calls Fetcher.fetch! directly, so the helpers must be module
    # methods. `extend self` also keeps them usable as a mixin.
    extend self

    # Download +url+ (a URI) and return the gunzipped response body.
    #
    # @raise [RuntimeError] if the server does not answer with a 2xx status
    def fetch_gzip_url(url)
      response = Net::HTTP.get_response(url)
      unless response.is_a?(Net::HTTPSuccess)
        raise RuntimeError, "Failed to fetch #{url}: HTTP #{response.code}"
      end

      Zlib.gunzip(response.body)
    end

    # Fetch every file in MNIST_URIS into this file's directory, without the
    # ".gz" suffix.
    def fetch!
      MNIST_URIS.each do |uri|
        uri = URI.parse(uri)
        path = PATH.dirname.join(File.basename(uri.path, ".gz"))
        # binwrite: IDX files are binary and must not be written in text mode.
        File.binwrite(path, fetch_gzip_url(uri))
      end
    end
  end

  # One labelled image. Pixels are a flat, row-major array of Width * Height values.
  class Example
    attr_accessor :label
    attr_accessor :pixels
    attr_accessor :angle
    attr_accessor :offset_x
    attr_accessor :offset_y

    def initialize(label, pixels, angle = 0, offset_x = 0, offset_y = 0)
      @label = label
      @pixels = pixels
      @angle = angle
      @offset_x = offset_x
      @offset_y = offset_y
    end

    # Pixel value at (x, y); missing values read as 0.
    def pixel(x, y)
      @pixels[y * MNist::Width + x] || 0
    end

    # Render the image as Height lines of Width characters.
    def to_ascii
      Height.times.map do |y|
        Width.times.map { |x| char_for_pixel(pixel(x, y)) }.join + "\n"
      end.join
    end

    # Yield every pixel value in row-major order.
    def each_pixel(&block)
      return to_enum(__method__) unless block_given?

      Height.times do |y|
        Width.times do |x|
          yield(pixel(x, y))
        end
      end
    end

    PixelValues = ' -+X#'.freeze

    private

    # Map a 0..255 value onto PixelValues. The index is clamped so that
    # transformed (rotated/translated) values outside that range cannot
    # index out of bounds or wrap around.
    def char_for_pixel(p)
      index = (p / 256.0 * PixelValues.length).to_i
      PixelValues[index.clamp(0, PixelValues.length - 1)]
    end
  end

  # Reads a label file and an image file in lockstep. Call #close when done.
  class DataStreamer
    def initialize(labels_path, images_path)
      @labels, @size = open_labels(labels_path)
      begin
        @images, @image_size = open_images(images_path)
      rescue StandardError
        # Don't leak the already-open labels file if the images file is bad.
        @labels.close
        raise
      end
    end

    def close
      @labels.close
      @images.close
    end

    def size
      @size
    end

    # Next Example, or nil once either file is exhausted.
    def next
      label = next_label
      pixels = next_image
      Example.new(label, pixels) if label && pixels
    end

    private

    # Open +path+, validate the IDX1 header, and return [file, count].
    # The file is closed again if validation fails.
    def open_labels(path)
      f = File.open(path, "rb")
      magic, number = f.read(4 * 2).to_s.unpack('NN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{path}") if magic != 0x801

      [ f, number ]
    rescue StandardError
      f&.close
      raise
    end

    def next_label
      byte = @labels.read(1)
      byte && byte.unpack1('C')
    end

    # Open +path+, validate the IDX3 header, and return [file, bytes_per_image].
    # The file is closed again if validation fails.
    def open_images(path)
      f = File.open(path, "rb")
      magic, _num_images, height, width = f.read(4 * 4).to_s.unpack('NNNN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{path}") if magic != 0x803

      [ f, width * height * 1 ]
    rescue StandardError
      f&.close
      raise
    end

    # Next image's bytes, or nil at EOF. A truncated final image is treated
    # as EOF rather than returned padded with nils.
    def next_image
      data = @images.read(@image_size)
      data.unpack('C*') if data && data.bytesize == @image_size
    end
  end

  # Re-enumerable view over a pair of MNIST label/image files. The bundled
  # train/test files are downloaded on first use if they are missing.
  class DataStream
    # @raise [ArgumentError] if either file does not exist
    # @raise [RuntimeError] if a file has the wrong IDX magic number
    def initialize(labels_path = TRAIN_LABELS_PATH, images_path = TRAIN_IMAGES_PATH)
      if bundled_pair?(labels_path, images_path) &&
         (!File.exist?(labels_path) || !File.exist?(images_path))
        Fetcher.fetch!
      end

      # File.exist? -- File.exists? was removed in Ruby 3.2.
      raise ArgumentError.new("File does not exist: #{labels_path}") unless File.exist?(labels_path)
      raise ArgumentError.new("File does not exist: #{images_path}") unless File.exist?(images_path)

      @labels_path = labels_path
      @images_path = images_path

      read_metadata
    end

    attr_reader :size
    attr_reader :width
    attr_reader :height

    # Yield every Example. Each call opens fresh file handles and closes
    # them even when the caller stops early.
    def each(&block)
      return enum_for(__method__) unless block_given?

      streamer = DataStreamer.new(@labels_path, @images_path)
      begin
        while (ex = streamer.next)
          block.call(ex)
        end
      ensure
        streamer.close
      end
    end

    def to_enum
      each
    end

    private

    # True when the paths name one of the bundled train/test pairs. Compared
    # as Pathnames so String arguments also match.
    def bundled_pair?(labels_path, images_path)
      pair = [Pathname.new(labels_path), Pathname.new(images_path)]
      pair == [TRAIN_LABELS_PATH, TRAIN_IMAGES_PATH] ||
        pair == [TEST_LABELS_PATH, TEST_IMAGES_PATH]
    end

    def read_metadata
      read_size
      read_dimensions
    end

    def read_dimensions
      File.open(@images_path, "rb") do |f|
        magic, _num_images, height, width = f.read(4 * 4).to_s.unpack('NNNN')
        raise RuntimeError.new("Invalid magic number #{magic} in #{@images_path}") if magic != 0x803

        @width = width
        @height = height
      end
    end

    def read_size
      File.open(@labels_path, "rb") do |f|
        magic, number = f.read(4 * 2).to_s.unpack('NN')
        raise RuntimeError.new("Invalid magic number #{magic} in #{@labels_path}") if magic != 0x801

        @size = number
      end
    end

    public

    # Emits +rotations+ rotated copies of every source example. Angles span
    # rotation_range centred on zero, either evenly spaced or random.
    class Rotator < Enumerator
      def initialize(data, rotations, rotation_range, random = false)
        @data = data.to_enum
        @rotations = rotations
        @rotation_range = rotation_range
        @random = random

        super() do |yielder|
          # Internal iteration: the rotator can be enumerated repeatedly, and
          # the source's cleanup runs when a consumer stops early.
          @data.each do |example|
            @rotations.times do |r|
              t = @random ? rand : (r / @rotations.to_f)
              theta = t * @rotation_range - @rotation_range / 2.0
              img = rotate_pixels(example.pixels, theta)
              yielder << Example.new(example.label, img.to_a.flatten, theta)
            end
          end
        end
      end

      def wrap(enum)
        self.class.new(enum, @rotations, @rotation_range, @random)
      end

      # Skip the first +n+ *source* examples (not n rotated outputs). Lazy,
      # so the remaining data set is not loaded into memory.
      def drop(n)
        wrap(@data.lazy.drop(n))
      end

      def rotate_pixels(pixels, theta)
        rot = CooCoo::Image::Rotate.new(MNist::Width / 2, MNist::Height / 2, theta)
        img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
        (img * rot)
      end
    end

    # Emits +num_translations+ shifted copies of every source example.
    # Offsets span dx/dy centred on zero, either evenly spaced or random.
    class Translator < Enumerator
      def initialize(data, num_translations, dx, dy, random = false)
        @data = data.to_enum
        @num_translations = num_translations
        @dx = dx
        @dy = dy
        @random = random

        super() do |yielder|
          @data.each do |example|
            @num_translations.times do |r|
              x = @random ? rand : (r / @num_translations.to_f)
              x = x * @dx - @dx / 2.0
              y = @random ? rand : (r / @num_translations.to_f)
              y = y * @dy - @dy / 2.0
              img = translate_pixels(example.pixels, x, y)
              yielder << Example.new(example.label, img.to_a.flatten, example.angle, x, y)
            end
          end
        end
      end

      def wrap(enum)
        self.class.new(enum, @num_translations, @dx, @dy, @random)
      end

      # Skip the first +n+ *source* examples, lazily.
      def drop(n)
        wrap(@data.lazy.drop(n))
      end

      def translate_pixels(pixels, x, y)
        transform = CooCoo::Image::Translate.new(x, y)
        img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
        (img * transform)
      end
    end
  end

  # Adapts a stream of Examples into [one_hot_target, scaled_pixels] pairs.
  class TrainingSet
    def initialize(data_stream = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH))
      @stream = data_stream
    end

    def each(&block)
      return to_enum(__method__) unless block

      @stream.each do |example|
        target = Array.new(10, 0.0)
        target[example.label] = 1.0
        block.call([ CooCoo::Vector[target],
                     CooCoo::Vector[example.pixels] / 256.0
                   ])
      end
    end
  end

  # Eagerly loaded collection of Examples.
  class DataSet
    attr_reader :examples

    def initialize
      @examples = Array.new
    end

    def load!(labels_path, images_path)
      @examples = DataStream.new(labels_path, images_path).each.to_a
      self
    end
  end
end
|
338
|
+
|
339
|
+
if __FILE__ == $0
  # Print an example's label followed by its ASCII rendering.
  def print_example(ex)
    puts(ex.label)
    puts(ex.to_ascii)
  end

  data = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH)

  # Keep the first example of each digit and stop once all ten have been
  # seen. Grouping the whole set would materialize all 60k examples.
  first_of_label = {}
  data.each do |ex|
    first_of_label[ex.label] ||= ex
    break if first_of_label.size >= 10
  end

  first_of_label.sort_by(&:first).each_with_index do |(_label, e), i|
    puts(i)
    print_example(e)
  end

  rot = MNist::DataStream::Rotator.new(data.each, 10, Math::PI, false)
  rot.drop(10).first(10).each do |example|
    print_example(example)
  end

  puts("#{data.size} total #{data.width}x#{data.height} images")
end
|