CooCoo 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +16 -0
- data/CooCoo.gemspec +47 -0
- data/Gemfile +4 -0
- data/Gemfile.lock +88 -0
- data/README.md +123 -0
- data/Rakefile +81 -0
- data/bin/cuda-dev-info +25 -0
- data/bin/cuda-free +28 -0
- data/bin/cuda-free-trend +7 -0
- data/bin/ffi-gen +267 -0
- data/bin/spec_runner_html.sh +42 -0
- data/bin/trainer +198 -0
- data/bin/trend-cost +13 -0
- data/examples/char-rnn.rb +405 -0
- data/examples/cifar/cifar.rb +94 -0
- data/examples/img-similarity.rb +201 -0
- data/examples/math_ops.rb +57 -0
- data/examples/mnist.rb +365 -0
- data/examples/mnist_classifier.rb +293 -0
- data/examples/mnist_dream.rb +214 -0
- data/examples/seeds.rb +268 -0
- data/examples/seeds_dataset.txt +210 -0
- data/examples/t10k-images-idx3-ubyte +0 -0
- data/examples/t10k-labels-idx1-ubyte +0 -0
- data/examples/train-images-idx3-ubyte +0 -0
- data/examples/train-labels-idx1-ubyte +0 -0
- data/ext/buffer/Rakefile +50 -0
- data/ext/buffer/buffer.pre.cu +727 -0
- data/ext/buffer/matrix.pre.cu +49 -0
- data/lib/CooCoo.rb +1 -0
- data/lib/coo-coo.rb +18 -0
- data/lib/coo-coo/activation_functions.rb +344 -0
- data/lib/coo-coo/consts.rb +5 -0
- data/lib/coo-coo/convolution.rb +298 -0
- data/lib/coo-coo/core_ext.rb +75 -0
- data/lib/coo-coo/cost_functions.rb +91 -0
- data/lib/coo-coo/cuda.rb +116 -0
- data/lib/coo-coo/cuda/device_buffer.rb +240 -0
- data/lib/coo-coo/cuda/device_buffer/ffi.rb +109 -0
- data/lib/coo-coo/cuda/error.rb +51 -0
- data/lib/coo-coo/cuda/host_buffer.rb +117 -0
- data/lib/coo-coo/cuda/runtime.rb +157 -0
- data/lib/coo-coo/cuda/vector.rb +315 -0
- data/lib/coo-coo/data_sources.rb +2 -0
- data/lib/coo-coo/data_sources/xournal.rb +25 -0
- data/lib/coo-coo/data_sources/xournal/bitmap_stream.rb +197 -0
- data/lib/coo-coo/data_sources/xournal/document.rb +377 -0
- data/lib/coo-coo/data_sources/xournal/loader.rb +144 -0
- data/lib/coo-coo/data_sources/xournal/renderer.rb +101 -0
- data/lib/coo-coo/data_sources/xournal/saver.rb +99 -0
- data/lib/coo-coo/data_sources/xournal/training_document.rb +78 -0
- data/lib/coo-coo/data_sources/xournal/training_document/constants.rb +15 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_maker.rb +89 -0
- data/lib/coo-coo/data_sources/xournal/training_document/document_reader.rb +105 -0
- data/lib/coo-coo/data_sources/xournal/training_document/example.rb +37 -0
- data/lib/coo-coo/data_sources/xournal/training_document/sets.rb +76 -0
- data/lib/coo-coo/debug.rb +8 -0
- data/lib/coo-coo/dot.rb +129 -0
- data/lib/coo-coo/drawing.rb +4 -0
- data/lib/coo-coo/drawing/cairo_canvas.rb +100 -0
- data/lib/coo-coo/drawing/canvas.rb +68 -0
- data/lib/coo-coo/drawing/chunky_canvas.rb +101 -0
- data/lib/coo-coo/drawing/sixel.rb +214 -0
- data/lib/coo-coo/enum.rb +17 -0
- data/lib/coo-coo/from_name.rb +58 -0
- data/lib/coo-coo/fully_connected_layer.rb +205 -0
- data/lib/coo-coo/generation_script.rb +38 -0
- data/lib/coo-coo/grapher.rb +140 -0
- data/lib/coo-coo/image.rb +286 -0
- data/lib/coo-coo/layer.rb +67 -0
- data/lib/coo-coo/layer_factory.rb +26 -0
- data/lib/coo-coo/linear_layer.rb +59 -0
- data/lib/coo-coo/math.rb +607 -0
- data/lib/coo-coo/math/abstract_vector.rb +121 -0
- data/lib/coo-coo/math/functions.rb +39 -0
- data/lib/coo-coo/math/interpolation.rb +7 -0
- data/lib/coo-coo/network.rb +264 -0
- data/lib/coo-coo/neuron.rb +112 -0
- data/lib/coo-coo/neuron_layer.rb +168 -0
- data/lib/coo-coo/option_parser.rb +18 -0
- data/lib/coo-coo/platform.rb +17 -0
- data/lib/coo-coo/progress_bar.rb +11 -0
- data/lib/coo-coo/recurrence/backend.rb +99 -0
- data/lib/coo-coo/recurrence/frontend.rb +101 -0
- data/lib/coo-coo/sequence.rb +187 -0
- data/lib/coo-coo/shell.rb +2 -0
- data/lib/coo-coo/temporal_network.rb +291 -0
- data/lib/coo-coo/trainer.rb +21 -0
- data/lib/coo-coo/trainer/base.rb +67 -0
- data/lib/coo-coo/trainer/batch.rb +82 -0
- data/lib/coo-coo/trainer/batch_stats.rb +27 -0
- data/lib/coo-coo/trainer/momentum_stochastic.rb +59 -0
- data/lib/coo-coo/trainer/stochastic.rb +47 -0
- data/lib/coo-coo/transformer.rb +272 -0
- data/lib/coo-coo/vector_layer.rb +194 -0
- data/lib/coo-coo/version.rb +3 -0
- data/lib/coo-coo/weight_deltas.rb +23 -0
- data/prototypes/convolution.rb +116 -0
- data/prototypes/linear_drop.rb +51 -0
- data/prototypes/recurrent_layers.rb +79 -0
- data/www/images/screamer.png +0 -0
- data/www/images/screamer.xcf +0 -0
- data/www/index.html +82 -0
- metadata +373 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
module CooCoo::Cifar
|
|
2
|
+
ROOT = Pathname.new(__FILE__).dirname
|
|
3
|
+
BINARY_BATCHES = {
|
|
4
|
+
labels: ROOT.join("cifar-10-batches-bin", "batches.meta.txt"),
|
|
5
|
+
batches: 5.times.collect { |i| ROOT.join("cifar-10-batches-bin", "data_batch_#{i + 1}.bin") },
|
|
6
|
+
test_batch: ROOT.join("cifar-10-batches-bin", "test_batch.bin")
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
# Maps numeric class ids to the label names stored one per line in a text
# file.
class LabelSet
  # @param path [String, Pathname] label file to read; defaults to
  #   BINARY_BATCHES[:labels].
  def initialize(path = BINARY_BATCHES[:labels])
    load!(path)
  end

  # Replaces the current labels with the lines of +path+.
  #
  # Lines may end in LF or CRLF; the terminator is never part of a label, so
  # labels stay safe to embed in file names. Trailing blank lines are dropped.
  #
  # @param path [String, Pathname] file holding one label per line.
  # @return [Array<String>] the labels that were loaded.
  def load!(path)
    @labels = File.read(path).split(/\r?\n/)
  end

  # @param number [Integer] zero based class id.
  # @return [String, nil] the label on that line, or nil when out of range.
  def [](number)
    @labels[number]
  end
end
|
|
22
|
+
|
|
23
|
+
# Streams records out of one or more binary batch files.
#
# Each record is RECORD_SIZE bytes: one label byte followed by NUM_PIXELS red
# bytes, NUM_PIXELS green bytes and NUM_PIXELS blue bytes.
class Batch
  WIDTH = 32
  HEIGHT = 32
  NUM_PIXELS = WIDTH * HEIGHT
  NUM_CHANNELS = 3
  # Bytes per record: the label byte plus every colour plane.
  RECORD_SIZE = 1 + NUM_PIXELS * NUM_CHANNELS

  attr_reader :paths

  # @param paths [Array<String, Pathname>] batch files to read; when none are
  #   given BINARY_BATCHES[:batches] is used.
  def initialize(*paths)
    paths = BINARY_BATCHES[:batches] if paths.empty?
    @paths = []
    paths.each { |path| add(path) }
  end

  # Appends another batch file.
  # @return [Batch] self, so calls can be chained.
  def add(path)
    @paths << path
    self
  end

  # Yields every record of every file, in the order the files were added.
  #
  # @yieldparam batch [Integer] index of the file within #paths.
  # @yieldparam i [Integer] index of the record within that file.
  # @yieldparam label [Integer] label byte.
  # @yieldparam red [Array<Integer>] NUM_PIXELS red bytes.
  # @yieldparam green [Array<Integer>] NUM_PIXELS green bytes.
  # @yieldparam blue [Array<Integer>] NUM_PIXELS blue bytes.
  # @return [Enumerator] when no block is given.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @paths.each_with_index do |path, i|
      enumerate_file(path, i, &block)
    end
  end

  # Yields each complete record of a single file; see #each for the block
  # arguments. A trailing fragment shorter than RECORD_SIZE is reported on
  # stderr and skipped instead of being yielded with short or missing colour
  # planes.
  #
  # @param path [String, Pathname] file to read.
  # @param index [Integer] value passed through as the first block argument.
  def enumerate_file(path, index, &block)
    $stderr.puts("Loading #{path}")
    File.open(path, 'rb') do |f|
      i = 0
      while (record = f.read(RECORD_SIZE))
        if record.bytesize < RECORD_SIZE
          $stderr.puts("Ignoring #{record.bytesize} trailing bytes in #{path}")
          break
        end

        bytes = record.unpack('C*')
        block.call(index, i,
                   bytes[0],
                   bytes[1, NUM_PIXELS],
                   bytes[1 + NUM_PIXELS, NUM_PIXELS],
                   bytes[1 + NUM_PIXELS * 2, NUM_PIXELS])
        i += 1
      end
    end
  end
end
|
|
71
|
+
end
|
|
72
|
+
|
|
73
|
+
# Command line preview: `ruby cifar.rb [SKIP [COUNT]]` writes COUNT records,
# after skipping SKIP, as PNG files named "<batch>-<index>-<label>.png".
if __FILE__ == $0
  require 'chunky_png'

  c = CooCoo::Cifar::Batch.new
  labels = CooCoo::Cifar::LabelSet.new
  first, count = ARGV.collect(&:to_i)
  first ||= 0
  count ||= 32
  puts("Showing #{count}, skipping #{first}")

  width = CooCoo::Cifar::Batch::WIDTH
  height = CooCoo::Cifar::Batch::HEIGHT

  # Walk the records one at a time and stop as soon as enough were written, so
  # only the batch files that are actually needed get read.
  c.each.with_index do |(batch, i, label, red, green, blue), n|
    next if n < first
    break if n >= first + count

    file = "#{batch}-#{i}-#{labels[label]}.png"
    puts(file)
    png = ChunkyPNG::Image.new(width, height)
    height.times do |y|
      width.times do |x|
        offset = y * width + x
        png[x, y] = ChunkyPNG::Color.rgb(red[offset], green[offset], blue[offset])
      end
    end
    png.save(file, :interlace => true)
  end
end
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
require 'coo-coo'
|
|
2
|
+
require 'chunky_png'
|
|
3
|
+
|
|
4
|
+
# Loads a list of PNG files up front and hands them out as RGB component
# vectors.
class ImageStream
  attr_reader :images

  # @param images [Array<String>] paths of the PNG files to load.
  def initialize(*images)
    @images = images.map { |path| load_image(path) }
  end

  # Reads one PNG and stores its pixels row by row as interleaved red, green
  # and blue components; the alpha channel is not copied.
  #
  # @param path [String] PNG file to read.
  # @return [Array] +[path, width, height, pixels]+ where +pixels+ is the
  #   component vector divided by 256.0.
  def load_image(path)
    png = ChunkyPNG::Image.from_file(path)
    pixels = CooCoo::Vector.new(png.width * png.height * 3)
    row_length = png.width * 3

    png.pixels.each_slice(png.width).each_with_index do |row, row_index|
      components = row.flat_map do |colour|
        [ ChunkyPNG::Color.r(colour),
          ChunkyPNG::Color.g(colour),
          ChunkyPNG::Color.b(colour) ]
      end
      pixels[row_index * row_length, row_length] = components
    end

    [ path, png.width, png.height, pixels / 256.0 ]
  end

  # @return [Integer] number of loaded images.
  def size
    @images.size
  end

  # @return [Integer] components stored per pixel.
  def channels
    3
  end

  # Yields +path, width, height, pixels+ for every loaded image.
  # @return [Enumerator] when no block is given.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @images.each { |entry| yield(*entry) }
  end
end
|
|
42
|
+
|
|
43
|
+
# Cuts randomly placed rectangles out of the images of a stream.
#
# For every pass (+num_slices+ of them) and every image one random anchor is
# picked, and +chitters+ slices are taken around it. With more than one
# chitter each slice is jittered around the anchor by up to half a slice.
class ImageSlicer
  attr_reader :slice_width
  attr_reader :slice_height

  # @param num_slices [Integer] number of passes over the stream.
  # @param slice_width [Integer] slice width in pixels.
  # @param slice_height [Integer] slice height in pixels.
  # @param image_stream [#each, #size, #channels] source of images.
  # @param chitters [Integer] slices taken per image per pass; the default of
  #   0 produces no slices at all.
  def initialize(num_slices, slice_width, slice_height, image_stream, chitters = 0)
    @stream = image_stream
    @num_slices = num_slices
    @chitters = chitters
    @slice_width = slice_width
    @slice_height = slice_height
  end

  # @return [Integer] number of slices a full #each produces.
  def size
    @num_slices * @stream.size * @chitters
  end

  # @return [Integer] components per pixel, as reported by the stream.
  def channels
    @stream.channels
  end

  # Yields +path, slice, x, y+ for every slice.
  # @return [Enumerator] when no block is given.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @num_slices.times do
      @stream.each.with_index do |(path, width, height, pixels), _image_index|
        anchor_x = rand(width)
        anchor_y = rand(height)

        @chitters.times do
          x = place(anchor_x, @slice_width, width)
          y = place(anchor_y, @slice_height, height)

          # NOTE(review): x is a pixel column while the source row width
          # handed to slice_2d is width * channels components; confirm
          # slice_2d expects the origin in these units.
          slice = pixels.slice_2d(width * channels, height, x, y, @slice_width * channels, @slice_height)
          yield(path, slice, x, y)
        end
      end
    end
  end

  private

  # Jitters +anchor+ when more than one chitter is requested, then pulls the
  # position back so a slice of +extent+ does not run past +limit+.
  def place(anchor, extent, limit)
    position = anchor
    position += rand(extent) - extent / 2 if @chitters > 1
    position + extent > limit ? limit - extent : position
  end
end
|
|
87
|
+
|
|
88
|
+
# Turns slices from an ImageSlicer into training pairs: two slices
# interleaved row by row as the input, and a three element target describing
# whether they share a source image and how close together they were cut.
class TrainingSet
  attr_reader :slicer, :batch_size

  # @param slicer [ImageSlicer] source of +[path, slice, x, y]+ tuples.
  # @param batch_size [Integer] number of slices shuffled together when
  #   forming pairs.
  def initialize(slicer, batch_size)
    @batch_size = batch_size
    @slicer = slicer
  end

  # @return [Integer] number of slices the slicer produces.
  def size
    @slicer.size
  end

  # @return [Integer] length of a target vector.
  def output_size
    3
  end

  # @return [Integer] length of an input vector: two slices side by side.
  def input_size
    2 * @slicer.channels * @slicer.slice_width * @slicer.slice_height
  end

  # Builds one input vector in which every row of +a+ is followed by the
  # matching row of +b+.
  def interleave(a, b)
    row_length = @slicer.slice_width * @slicer.channels
    combined = CooCoo::Vector.zeros(input_size)

    (0...@slicer.slice_height).each do |y|
      offset = y * row_length
      combined[2 * offset, row_length] = a[offset, row_length]
      combined[2 * offset + row_length, row_length] = b[offset, row_length]
    end

    combined
  end

  # Yields +[target, input]+ pairs. Slices are taken +batch_size+ at a time
  # and two independent shuffles of each batch are paired element by element.
  # @return [Enumerator] when no block is given.
  def each(&block)
    return to_enum(__method__) unless block_given?

    @slicer.each.each_slice(@batch_size) do |batch|
      lefts = batch.shuffle
      rights = batch.shuffle
      lefts.zip(rights) do |a, b|
        yield([target_for(a, b), interleave(a[1], b[1])])
      end
    end
  end

  # Target for a pair of slicer tuples.
  #
  # Element 0 is 1.0 when both slices come from the same path. Elements 1 and
  # 2 are +slice_width / distance - 1+ and +slice_height / distance - 1+
  # limited to 0.0..1.0, where distance is the length of the offset between
  # the two slice origins. All three stay 0.0 for different paths.
  def target_for(a, b)
    target = CooCoo::Vector.zeros(output_size)
    return target unless a[0] == b[0]

    target[0] = 1.0
    distance = CooCoo::Vector[[b[2] - a[2], b[3] - a[3]]].magnitude
    target[1] = closeness(@slicer.slice_width, distance)
    target[2] = closeness(@slicer.slice_height, distance)
    target
  end

  private

  # +extent / distance - 1+ limited to the range 0.0..1.0.
  def closeness(extent, distance)
    score = extent.to_f / distance - 1.0
    if score > 1.0
      1.0
    elsif score <= 0.0
      0.0
    else
      score
    end
  end
end
|
|
147
|
+
|
|
148
|
+
# Data generator entry point, active only when this file is loaded by another
# program rather than run directly. The value of the block is
# +[training set factory, option parser]+.
if $0 != __FILE__
  require 'pathname'
  require 'ostruct'

  @options = OpenStruct.new(slice_width: 32,
                            slice_height: 32,
                            num_slices: 1000,
                            cycles: 100,
                            images: [],
                            chitters: 4)

  @opts = CooCoo::OptionParser.new do |o|
    o.banner = "Image Similarity Data Generator options"

    # May be given several times; every glob adds to the image list.
    o.on('--data-path PATH') do |path|
      @options.images = @options.images + Dir.glob(path)
    end

    # Integer valued switches and the option each one stores into.
    { '--data-slice-width SIZE' => :slice_width,
      '--data-slice-height SIZE' => :slice_height,
      '--data-slices NUM' => :num_slices,
      '--data-cycles NUM' => :cycles,
      '--data-chitters NUM' => :chitters
    }.each do |switch, field|
      o.on(switch) do |n|
        @options[field] = n.to_i
      end
    end
  end

  # Builds the training set from the parsed options. The cycles option is the
  # slicer's number of passes and num_slices is the pairing batch size.
  def training_set()
    slicer = ImageSlicer.new(@options.cycles,
                             @options.slice_width,
                             @options.slice_height,
                             ImageStream.new(*@options.images),
                             @options.chitters)
    TrainingSet.new(slicer, @options.num_slices.to_i)
  end

  [ method(:training_set), @opts ]
end
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
#!/usr/bin/env ruby

require 'coo-coo'

# Trains a small three layer network to approximate a simple function of a
# random three element vector, then prints a few predictions next to the
# expected values.
#
# NOTE(review): Vector is referenced without the CooCoo:: prefix throughout;
# confirm that 'coo-coo' makes it available at the top level.

# Candidate target functions; each maps an input vector to a one element
# vector. Assign the one to learn to `f` below.
average = lambda do |m|
  e = m.each
  Vector[[e.sum / e.count]]
end

xor = lambda do |m|
  Vector[[m.to_a.flatten.inject(0) { |acc, n| acc ^ (255.0 * n).to_i } / 256.0]]
end

max = lambda do |m|
  Vector[[m.each.max]]
end

# Generates +n+ training pairs of +[block.(input), input]+ from random three
# element inputs.
def data(n, &block)
  raise ArgumentError.new("Block not given") unless block_given?

  Array.new(n) do
    m = Vector.rand(3)
    [ block.(m), m ]
  end
end

# Runs +input+ through +model+ and prints the output, the expected value and
# their difference.
def print_prediction(model, input, expecting)
  output = model.forward(input)
  puts("#{input} -> #{output}, expecting #{expecting}, #{expecting - output}")
end

Random.srand(123)

f = max
training_data = data(1000, &f)
model = CooCoo::Network.new()
model.layer(CooCoo::Layer.new(3, 8))
model.layer(CooCoo::Layer.new(8, 8))
model.layer(CooCoo::Layer.new(8, 1))

puts("Training")
# A monotonic clock keeps the elapsed time right even if the wall clock is
# adjusted while training runs.
started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
model.train(network: model,
            data: training_data,
            learning_rate: 0.3,
            batch_size: 200)
elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started
puts("\tElapsed #{elapsed / 60.0} min")

puts("Predicting:")

print_prediction(model, training_data.first[1], training_data.first[0])
[ [0.5, 0.75, 0.25],
  [0.25, 0.0, 0.0],
  [1.0, 0.0, 0.0]
].each do |values|
  print_prediction(model, Vector[values], f.(Vector[values]))
end
|
data/examples/mnist.rb
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
require 'pathname'
|
|
2
|
+
require 'net/http'
|
|
3
|
+
require 'zlib'
|
|
4
|
+
require 'coo-coo/image'
|
|
5
|
+
|
|
6
|
+
module MNist
|
|
7
|
+
PATH = Pathname.new(__FILE__)
|
|
8
|
+
MNIST_URIS = [ "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
|
|
9
|
+
"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
|
10
|
+
"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
|
11
|
+
"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
|
|
12
|
+
]
|
|
13
|
+
TRAIN_LABELS_PATH = PATH.dirname.join('train-labels-idx1-ubyte')
|
|
14
|
+
TRAIN_IMAGES_PATH = PATH.dirname.join('train-images-idx3-ubyte')
|
|
15
|
+
TEST_LABELS_PATH = PATH.dirname.join('t10k-labels-idx1-ubyte')
|
|
16
|
+
TEST_IMAGES_PATH = PATH.dirname.join('t10k-images-idx3-ubyte')
|
|
17
|
+
|
|
18
|
+
Width = 28
|
|
19
|
+
Height = 28
|
|
20
|
+
|
|
21
|
+
# Downloads the gzip archives listed in MNIST_URIS and stores them unpacked
# next to this file.
module Fetcher
  require 'stringio'
  require 'uri'

  # DataStream calls Fetcher.fetch! directly on the module, so every method
  # must also be available as a module method.
  extend self

  # Downloads +url+ and returns the decompressed body.
  #
  # @param url [URI] location of a gzip file.
  # @return [String] the decompressed bytes.
  # @raise [RuntimeError] when the server does not answer with a 2xx status.
  def fetch_gzip_url(url)
    response = Net::HTTP.get_response(url)
    unless response.is_a?(Net::HTTPSuccess)
      raise RuntimeError.new("Failed to fetch #{url}: #{response.code} #{response.message}")
    end

    gunzip(response.body)
  end

  # @param data [String] gzip compressed bytes.
  # @return [String] the decompressed bytes.
  def gunzip(data)
    reader = Zlib::GzipReader.new(StringIO.new(data))
    begin
      reader.read
    ensure
      reader.close
    end
  end

  # Where the unpacked contents of +uri+ are stored: the archive's file name
  # without its trailing ".gz", inside +directory+.
  #
  # @param uri [URI] archive location.
  # @param directory [Pathname] destination directory.
  # @return [Pathname]
  def target_path(uri, directory = PATH.dirname)
    directory.join(File.basename(uri.path, ".gz"))
  end

  # Fetches and unpacks every archive in MNIST_URIS.
  def fetch!
    MNIST_URIS.each do |uri|
      uri = URI.parse(uri)
      data = fetch_gzip_url(uri)
      # Binary mode: the payload is raw bytes and must not go through newline
      # or encoding conversion.
      File.open(target_path(uri), "wb") do |f|
        f.write(data)
      end
    end
  end
end
|
|
38
|
+
|
|
39
|
+
# One labelled image together with the transformation that produced it.
class Example
  attr_accessor :label
  attr_accessor :pixels
  attr_accessor :angle
  attr_accessor :offset_x
  attr_accessor :offset_y

  # Characters used by #to_ascii, from darkest to brightest.
  PixelValues = ' -+X#'

  # @param label [Integer] class of the image.
  # @param pixels [Array<Numeric>] row major pixel values.
  # @param angle [Numeric] rotation that was applied to the image.
  # @param offset_x [Numeric] horizontal translation that was applied.
  # @param offset_y [Numeric] vertical translation that was applied.
  def initialize(label, pixels, angle = 0, offset_x = 0, offset_y = 0)
    @label = label
    @pixels = pixels
    @angle = angle
    @offset_x = offset_x
    @offset_y = offset_y
  end

  # @return [Numeric] the pixel at column +x+, row +y+, or 0 when the pixel
  #   data does not reach that far.
  def pixel(x, y)
    @pixels[y * MNist::Width + x] || 0
  end

  # Renders the image as text, one line per row.
  # @return [String]
  def to_ascii
    rows = MNist::Height.times.map do |y|
      MNist::Width.times.map { |x| char_for_pixel(pixel(x, y)) }.join
    end
    rows.map { |row| row + "\n" }.join
  end

  # Yields every pixel row by row.
  # @return [Enumerator] when no block is given.
  def each_pixel(&block)
    return to_enum(__method__) unless block_given?

    MNist::Height.times do |y|
      MNist::Width.times do |x|
        yield(pixel(x, y))
      end
    end
  end

  private

  # Picks the character for a pixel value on a 0...256 scale. Values outside
  # that scale (transformed images may produce them) use the darkest or
  # brightest character instead of indexing outside the table.
  def char_for_pixel(p)
    index = (p / 256.0 * PixelValues.length).to_i
    index = [ [ index, 0 ].max, PixelValues.length - 1 ].min
    PixelValues[index]
  end
end
|
|
85
|
+
|
|
86
|
+
# Sequential reader over a pair of IDX files: one holding labels, one holding
# images. Both files stay open until #close is called.
class DataStreamer
  LABELS_MAGIC = 0x801
  IMAGES_MAGIC = 0x803

  # @return [Integer] number of items announced by the labels header.
  attr_reader :size

  # @param labels_path [String, Pathname] IDX labels file.
  # @param images_path [String, Pathname] IDX images file.
  # @raise [RuntimeError] when either header is truncated or carries the
  #   wrong magic number; no file is left open in that case.
  def initialize(labels_path, images_path)
    @labels, @size = open_labels(labels_path)
    begin
      @images, @image_size = open_images(images_path)
    rescue StandardError
      # The labels file was opened successfully; do not leak it.
      @labels.close
      raise
    end
  end

  # Closes both files.
  def close
    @labels.close
    @images.close
  end

  # Reads the next label and image.
  # @return [Example, nil] nil once either file has no complete item left.
  def next
    label = next_label
    pixels = next_image
    Example.new(label, pixels) if label && pixels
  end

  private

  # @return [Array] +[file, number of labels]+
  def open_labels(path)
    f, fields = open_idx(path, LABELS_MAGIC, 1)
    [ f, fields[0] ]
  end

  # @return [Integer, nil] the next label byte, nil at end of file.
  def next_label
    @labels.getbyte
  end

  # @return [Array] +[file, bytes per image]+
  def open_images(path)
    f, fields = open_idx(path, IMAGES_MAGIC, 3)
    _num_images, height, width = fields
    [ f, width * height * 1 ]
  end

  # @return [Array<Integer>, nil] the bytes of the next image; nil at end of
  #   file or when only part of an image is left.
  def next_image
    data = @images.read(@image_size)
    data.bytes if data && data.bytesize == @image_size
  end

  # Opens +path+ and reads its header: a big endian 32 bit magic number
  # followed by +field_count+ big endian 32 bit integers.
  #
  # @return [Array] +[file, fields]+ with the file positioned after the header.
  # @raise [RuntimeError] on a short header or unexpected magic number; the
  #   file is closed before the error propagates.
  def open_idx(path, expected_magic, field_count)
    f = File.open(path, "rb")
    begin
      header_size = 4 * (1 + field_count)
      header = f.read(header_size)
      if header.nil? || header.bytesize < header_size
        raise RuntimeError.new("Truncated header in #{path}")
      end

      magic, *fields = header.unpack('N*')
      raise RuntimeError.new("Invalid magic number #{magic} in #{path}") if magic != expected_magic

      [ f, fields ]
    rescue StandardError
      f.close
      raise
    end
  end
end
|
|
145
|
+
|
|
146
|
+
# Repeatable stream of Example objects read from a labels file and an images
# file. Every call to #each opens the files afresh.
class DataStream
  # @return [Integer] number of examples announced by the labels header.
  attr_reader :size
  # @return [Integer] image width read from the images header.
  attr_reader :width
  # @return [Integer] image height read from the images header.
  attr_reader :height

  # When one of the two standard path pairs (training or test) is requested
  # and a file is missing, the data is downloaded first via Fetcher.fetch!.
  #
  # @param labels_path [String, Pathname] IDX labels file.
  # @param images_path [String, Pathname] IDX images file.
  # @raise [ArgumentError] when a file does not exist.
  # @raise [RuntimeError] when a header has the wrong magic number.
  def initialize(labels_path = TRAIN_LABELS_PATH, images_path = TRAIN_IMAGES_PATH)
    standard_pair = (labels_path == TRAIN_LABELS_PATH && images_path == TRAIN_IMAGES_PATH) ||
                    (labels_path == TEST_LABELS_PATH && images_path == TEST_IMAGES_PATH)
    if standard_pair && !(File.exist?(labels_path) && File.exist?(images_path))
      Fetcher.fetch!
    end

    raise ArgumentError.new("File does not exist: #{labels_path}") unless File.exist?(labels_path)
    raise ArgumentError.new("File does not exist: #{images_path}") unless File.exist?(images_path)

    @labels_path = labels_path
    @images_path = images_path

    read_metadata
  end

  # Yields every example in file order.
  # @return [Enumerator] when no block is given.
  def each(&block)
    return enum_for(__method__) unless block_given?

    # Opened outside the begin block: if opening fails there is nothing to
    # close and the original error must surface unchanged.
    streamer = DataStreamer.new(@labels_path, @images_path)
    begin
      while (ex = streamer.next)
        block.call(ex)
      end
    ensure
      streamer.close
    end
  end

  # @return [Enumerator] over #each.
  def to_enum
    each
  end

  # Expands every example of a source into +rotations+ rotated copies.
  class Rotator < Enumerator
    # @param data [#to_enum] source of examples.
    # @param rotations [Integer] copies produced per example.
    # @param rotation_range [Numeric] total angular range; the angles lie in
    #   -range/2 ... +range/2.
    # @param random [Boolean] pick angles at random instead of stepping
    #   evenly through the range.
    def initialize(data, rotations, rotation_range, random = false)
      @data = data.to_enum
      @rotations = rotations
      @rotation_range = rotation_range
      @random = random

      super() do |y|
        # Iterating with #each (rather than pulling with #next) restarts the
        # source on every pass, so this enumerator can be walked repeatedly.
        @data.each do |example|
          @rotations.times do |r|
            t = if @random
                  rand
                else
                  (r / @rotations.to_f)
                end
            theta = t * @rotation_range - @rotation_range / 2.0
            img = rotate_pixels(example.pixels, theta)
            y << Example.new(example.label, img.to_a.flatten, theta)
          end
        end
      end
    end

    # @return [Rotator] the same configuration over a different source.
    def wrap(enum)
      self.class.new(enum, @rotations, @rotation_range, @random)
    end

    # Skips the first +n+ source examples (not +n+ rotated copies).
    # @return [Rotator]
    def drop(n)
      wrap(@data.drop(n))
    end

    # Rotates an image about its centre by +theta+.
    def rotate_pixels(pixels, theta)
      rot = CooCoo::Image::Rotate.new(MNist::Width / 2, MNist::Height / 2, theta)
      img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
      (img * rot)
    end
  end

  # Expands every example of a source into +num_translations+ shifted copies.
  class Translator < Enumerator
    # @param data [#to_enum] source of examples.
    # @param num_translations [Integer] copies produced per example.
    # @param dx [Numeric] total horizontal range; offsets lie in
    #   -dx/2 ... +dx/2.
    # @param dy [Numeric] total vertical range; offsets lie in
    #   -dy/2 ... +dy/2.
    # @param random [Boolean] pick offsets at random instead of stepping
    #   evenly through the ranges.
    def initialize(data, num_translations, dx, dy, random = false)
      @data = data.to_enum
      @num_translations = num_translations
      @dx = dx
      @dy = dy
      @random = random

      super() do |yielder|
        # See Rotator: #each keeps the enumerator restartable.
        @data.each do |example|
          @num_translations.times do |r|
            x = if @random
                  rand
                else
                  (r / @num_translations.to_f)
                end
            x = x * @dx - @dx / 2.0
            y = if @random
                  rand
                else
                  (r / @num_translations.to_f)
                end
            y = y * @dy - @dy / 2.0
            img = translate_pixels(example.pixels, x, y)
            yielder << Example.new(example.label, img.to_a.flatten, example.angle, x, y)
          end
        end
      end
    end

    # @return [Translator] the same configuration over a different source.
    def wrap(enum)
      self.class.new(enum, @num_translations, @dx, @dy, @random)
    end

    # Skips the first +n+ source examples (not +n+ shifted copies).
    # @return [Translator]
    def drop(n)
      wrap(@data.drop(n))
    end

    # Shifts an image by +x+, +y+.
    def translate_pixels(pixels, x, y)
      transform = CooCoo::Image::Translate.new(x, y)
      img = CooCoo::Image::Base.new(MNist::Width, MNist::Height, 1, pixels.to_a.flatten)
      (img * transform)
    end
  end

  private

  def read_metadata
    read_size
    read_dimensions
  end

  # Reads the image dimensions from the images file header.
  def read_dimensions
    File.open(@images_path, "rb") do |f|
      magic, _num_images, height, width = f.read(4 * 4).to_s.unpack('NNNN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{@images_path}") if magic != 0x803

      @width = width
      @height = height
    end
  end

  # Reads the number of examples from the labels file header.
  def read_size
    File.open(@labels_path, "rb") do |f|
      magic, number = f.read(4 * 2).to_s.unpack('NN')
      raise RuntimeError.new("Invalid magic number #{magic} in #{@labels_path}") if magic != 0x801

      @size = number
    end
  end
end
|
|
300
|
+
|
|
301
|
+
# Adapts a stream of MNist examples to +[target, input]+ vector pairs.
class TrainingSet
  # @param data_stream [#each] source of examples responding to +label+ and
  #   +pixels+; defaults to the MNist training files.
  def initialize(data_stream = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH))
    @stream = data_stream
  end

  # Yields one +[target, input]+ pair per example: the target is a ten
  # element vector with 1.0 at the label's index, the input is the pixel
  # vector divided by 256.0.
  #
  # The stream is walked with its own #each, so its cleanup code runs even
  # when the caller breaks out early.
  #
  # @return [Enumerator] when no block is given.
  def each(&block)
    return to_enum(__method__) unless block

    @stream.each do |example|
      one_hot = Array.new(10, 0.0)
      one_hot[example.label] = 1.0
      block.call([ CooCoo::Vector[one_hot],
                   CooCoo::Vector[example.pixels] / 256.0
                 ])
    end
  end
end
|
|
323
|
+
|
|
324
|
+
# In-memory collection of examples.
class DataSet
  attr_reader :examples

  # Starts out empty; use #load! to fill it.
  def initialize
    @examples = []
  end

  # Replaces the examples with everything a DataStream over the two files
  # produces.
  #
  # @param labels_path [String, Pathname] labels file.
  # @param images_path [String, Pathname] images file.
  # @return [DataSet] self.
  def load!(labels_path, images_path)
    stream = DataStream.new(labels_path, images_path)
    @examples = stream.each.to_a
    self
  end
end
|
|
336
|
+
|
|
337
|
+
end
|
|
338
|
+
|
|
339
|
+
# Command line preview: prints the first example of each label as ASCII art,
# then a few rotated examples, then a summary line.
if __FILE__ == $0
  def print_example(ex)
    puts(ex.label)
    puts(ex.to_ascii)
  end

  data = MNist::DataStream.new(MNist::TRAIN_LABELS_PATH, MNist::TRAIN_IMAGES_PATH)

  # First example seen for every label, ordered by label.
  first_per_label = data.each.
    group_by(&:label).
    map { |label, group| [ label, group.first ] }.
    sort_by(&:first)

  first_per_label.first(10).each_with_index do |(_label, example), index|
    puts(index)
    print_example(example)
  end

  rot = MNist::DataStream::Rotator.new(data.each, 10, Math::PI, false)
  rot.drop(10).first(10).each { |example| print_example(example) }

  puts("#{data.size} total #{data.width}x#{data.height} images")
end
|