ruby-dnn 0.10.1 → 0.10.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar100_example.rb +71 -71
- data/examples/cifar10_example.rb +68 -68
- data/examples/iris_example.rb +34 -34
- data/examples/mnist_conv2d_example.rb +50 -50
- data/examples/mnist_example.rb +39 -39
- data/examples/mnist_lstm_example.rb +36 -36
- data/examples/xor_example.rb +24 -24
- data/lib/dnn.rb +27 -26
- data/lib/dnn/cifar10.rb +51 -51
- data/lib/dnn/cifar100.rb +49 -49
- data/lib/dnn/core/activations.rb +148 -148
- data/lib/dnn/core/cnn_layers.rb +464 -464
- data/lib/dnn/core/dataset.rb +34 -34
- data/lib/dnn/core/embedding.rb +56 -0
- data/lib/dnn/core/error.rb +5 -5
- data/lib/dnn/core/initializers.rb +126 -126
- data/lib/dnn/core/layers.rb +307 -307
- data/lib/dnn/core/losses.rb +175 -175
- data/lib/dnn/core/model.rb +461 -461
- data/lib/dnn/core/normalizations.rb +72 -72
- data/lib/dnn/core/optimizers.rb +283 -283
- data/lib/dnn/core/param.rb +9 -9
- data/lib/dnn/core/regularizers.rb +106 -106
- data/lib/dnn/core/rnn_layers.rb +464 -464
- data/lib/dnn/core/utils.rb +34 -34
- data/lib/dnn/downloader.rb +50 -50
- data/lib/dnn/image.rb +41 -41
- data/lib/dnn/iris.rb +60 -60
- data/lib/dnn/mnist.rb +84 -84
- data/lib/dnn/version.rb +3 -3
- metadata +2 -1
data/lib/dnn/core/utils.rb
CHANGED
@@ -1,34 +1,34 @@
module DNN
  # Small helpers shared across the library.
  module Utils
    # One-hot encode the label vector +y+.
    # Returns a new array of shape (y.shape[0], num_classes) built with +narray_type+
    # (defaults to y's own class) in which row i has a 1 in column y[i] and 0 elsewhere.
    def self.to_categorical(y, num_classes, narray_type = nil)
      result_class = narray_type || y.class
      num_samples = y.shape[0]
      result_class.zeros(num_samples, num_classes).tap do |one_hot|
        num_samples.times { |row| one_hot[row, y[row]] = 1 }
      end
    end

    # Rebuild an object from its hash form.
    # hash[:class] names a constant looked up from DNN. If that class implements .from_hash it
    # receives the whole hash; otherwise the class is instantiated with no arguments.
    # Returns nil when +hash+ is nil.
    def self.from_hash(hash)
      return nil if hash.nil?
      klass = DNN.const_get(hash[:class])
      klass.respond_to?(:from_hash) ? klass.from_hash(hash) : klass.new
    end

    # Sigmoid of +x+, computed by delegating to Sigmoid#forward.
    def self.sigmoid(x)
      Sigmoid.new.forward(x)
    end

    # Softmax of +x+, computed by delegating to SoftmaxCrossEntropy.softmax.
    def self.softmax(x)
      SoftmaxCrossEntropy.softmax(x)
    end
  end
end
data/lib/dnn/downloader.rb
CHANGED
@@ -1,50 +1,50 @@
require "net/http"
require "uri"
require_relative "core/error"

module DNN

  class DNN_DownloadError < DNN_Error; end

  # Downloads a single file over HTTP or HTTPS, printing a text progress bar to stdout.
  class Downloader
    # Number of characters in the progress bar.
    PROGRESS_BAR_WIDTH = 40

    # Download +url+ into the directory +dir_path+.
    # When +dir_path+ is nil, a "downloads" directory next to this file is used (created if missing).
    # Any StandardError raised along the way is re-raised as DNN_DownloadError with the same message.
    def self.download(url, dir_path = nil)
      unless dir_path
        dir_path = "#{__dir__}/downloads"
        Dir.mkdir(dir_path) unless Dir.exist?(dir_path)
      end
      Downloader.new(url).download(dir_path)
    rescue => ex
      raise DNN_DownloadError.new(ex.message)
    end

    # +url+ must be an absolute http:// or https:// URL; anything else raises ArgumentError
    # (previously an unparsable URL silently left the host and path nil).
    def initialize(url)
      @url = url
      @uri = URI.parse(url)
      if !@uri.is_a?(URI::HTTP) || @uri.host.to_s.empty?
        raise ArgumentError.new(%`unsupported URL "#{url}"`)
      end
      @fqdn = @uri.host
      @path = @uri.request_uri
    end

    # Fetch the file and write it to "#{dir_path}/<last path segment of the URL>".
    # TLS is used for https URLs, on the port given by the URL.
    # A non-2xx response raises (redirects are not followed) instead of saving the error body.
    # Returns the value of File.binwrite (number of bytes written).
    def download(dir_path)
      puts %`download "#{@url}"`
      buf = String.new # binary buffer; body segments are appended as they arrive
      Net::HTTP.start(@uri.host, @uri.port, use_ssl: @uri.scheme == "https") do |http|
        http.request_get(@path) do |res|
          res.value # raises a Net::HTTP error for any non-2xx status
          content_length = res.content_length
          res.read_body do |body_segment|
            buf << body_segment
            print progress_log(buf.bytesize, content_length)
          end
        end
        puts ""
      end
      file_name = File.basename(@uri.path)
      File.binwrite("#{dir_path}/#{file_name}", buf)
    end

    private

    # Build one progress line, prefixed with "\r" so each print overwrites the previous one.
    # Without a usable Content-Length (nil or 0) only the received byte count is shown.
    def progress_log(received, total)
      return "\r#{received}/?" unless total && total > 0
      # Clamp so the bar never grows past its width (e.g. when the body is transparently decoded).
      filled = [received * PROGRESS_BAR_WIDTH / total, PROGRESS_BAR_WIDTH].min
      bar = "=" * filled
      bar << ">" << "_" * (PROGRESS_BAR_WIDTH - filled - 1) if filled < PROGRESS_BAR_WIDTH
      "\r#{bar} #{received}/#{total}"
    end
  end

end
data/lib/dnn/image.rb
CHANGED
@@ -1,41 +1,41 @@
require "numo/narray"
require_relative "../../ext/rb_stb_image/rb_stb_image"

module DNN
  # Image file reading/writing through the bundled stb_image extension.
  module Image
    # Read +file_name+ and return its pixels as a Numo::UInt8 array of shape [height, width, 3]
    # (the image is always requested with 3 channels).
    # Raises Image::ReadError if the file does not exist.
    def self.read(file_name)
      raise Image::ReadError.new("#{file_name} is not found.") unless File.exist?(file_name)
      bin, w, h, _n = Stb.stbi_load(file_name, 3)
      img = Numo::UInt8.from_binary(bin)
      img.reshape(h, w, 3)
    end

    # Write +img+ to +file_name+. The format is chosen from the extension
    # (.png, .bmp, .jpg or .jpeg, case-insensitive).
    # +img+ may be [h, w] or [h, w, 1] (grayscale, expanded to 3 identical channels)
    # or [h, w, ch]. Non-UInt8 data is cast to UInt8 before writing.
    # +quality+ is only used for JPEG output.
    # Raises Image::WriteError on any failure, including an unsupported extension.
    def self.write(file_name, img, quality: 100)
      if img.shape.length == 2
        img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
      elsif img.shape[2] == 1
        img = img.reshape(img.shape[0], img.shape[1])
        img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
      end
      # stb expects one byte per channel; writing another dtype's binary would give a mis-sized buffer.
      img = Numo::UInt8.cast(img) unless img.is_a?(Numo::UInt8)
      h, w, ch = img.shape
      bin = img.to_binary
      case file_name
      when /\.png\z/i
        stride_in_bytes = w * ch
        Stb.stbi_write_png(file_name, w, h, ch, bin, stride_in_bytes)
      when /\.bmp\z/i
        Stb.stbi_write_bmp(file_name, w, h, ch, bin)
      when /\.jpe?g\z/i
        Stb.stbi_write_jpg(file_name, w, h, ch, bin, quality)
      else
        raise Image::WriteError.new("#{file_name} has an unsupported image format.")
      end
    rescue => ex
      raise Image::WriteError.new(ex.message)
    end
  end

  # Base class of all image I/O errors.
  class Image::Error < StandardError; end

  # Raised by Image.read.
  class Image::ReadError < Image::Error; end

  # Raised by Image.write.
  class Image::WriteError < Image::Error; end
end
data/lib/dnn/iris.rb
CHANGED
@@ -1,60 +1,60 @@
require "csv"
require_relative "core/error"
require_relative "downloader"

module DNN
  class DNN_Iris_LoadError < DNN_Error; end

  # Loader for the Iris CSV dataset, downloaded on first use.
  module Iris
    URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"

    # Iris-setosa
    SETOSA = 0
    # Iris-versicolor
    VERSICOLOR = 1
    # Iris-virginica
    VIRGINICA = 2

    # Download the CSV file unless it already exists locally.
    def self.downloads
      return if File.exist?(url_to_file_name(URL_CSV))
      Downloader.download(URL_CSV)
    end

    # Load the dataset and return [x, y].
    # x is a Numo::SFloat of shape [num_rows, 4] (sepal length, sepal width, petal length,
    # petal width); y is a Numo::SFloat of class ids (SETOSA, VERSICOLOR, VIRGINICA).
    # Blank CSV lines are skipped; an unknown class name raises DNN_Iris_LoadError.
    # When +shuffle+ is true, rows of x and y are permuted together. The permutation depends
    # only on +shuffle_seed+, and the global random number generator is not touched.
    def self.load(shuffle = false, shuffle_seed = rand(1 << 31))
      downloads
      csv_array = CSV.read(url_to_file_name(URL_CSV)).reject(&:empty?)
      x = Numo::SFloat.zeros(csv_array.length, 4)
      y = Numo::SFloat.zeros(csv_array.length)
      csv_array.each_with_index do |(sepal_length, sepal_width, petal_length, petal_width, classes), i|
        x[i, 0] = sepal_length.to_f
        x[i, 1] = sepal_width.to_f
        x[i, 2] = petal_length.to_f
        x[i, 3] = petal_width.to_f
        y[i] = case classes
               when "Iris-setosa"
                 SETOSA
               when "Iris-versicolor"
                 VERSICOLOR
               when "Iris-virginica"
                 VIRGINICA
               else
                 raise DNN_Iris_LoadError.new("Unknown class name '#{classes}' for iris")
               end
      end
      if shuffle
        # A dedicated Random instance avoids srand on the global generator and does not
        # rely on Random::DEFAULT, which no longer exists in Ruby 3.2+.
        indexs = (0...csv_array.length).to_a.shuffle(random: Random.new(shuffle_seed))
        # Store from copies so the permuted assignment never reads from the memory it writes.
        x[indexs, true] = x.dup
        y[indexs] = y.dup
      end
      [x, y]
    end

    # Local path of the downloaded copy of +url+.
    def self.url_to_file_name(url)
      __dir__ + "/downloads/" + url.match(%r`.+/(.+)$`)[1]
    end

    private_class_method :url_to_file_name
  end
end
data/lib/dnn/mnist.rb
CHANGED
@@ -1,84 +1,84 @@
require "zlib"
require_relative "core/error"
require_relative "downloader"

module DNN
  # Loader for the MNIST dataset (gzip-compressed IDX files), downloaded on first use
  # into "downloads/mnist" next to this file.
  module MNIST
    class DNN_MNIST_LoadError < DNN_Error; end

    URL_TRAIN_IMAGES = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
    URL_TRAIN_LABELS = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
    URL_TEST_IMAGES = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
    URL_TEST_LABELS = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"

    # Magic numbers expected in the first header field of the image and label files.
    IMAGES_MAGIC = 2051
    LABELS_MAGIC = 2049

    # Download every MNIST file that is not already on disk.
    # Checking each file (not just the directory) lets an interrupted download be completed later.
    def self.downloads
      downloads_dir = "#{__dir__}/downloads"
      Dir.mkdir(downloads_dir) unless Dir.exist?(downloads_dir)
      Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
      [URL_TRAIN_IMAGES, URL_TRAIN_LABELS, URL_TEST_IMAGES, URL_TEST_LABELS].each do |url|
        Downloader.download(url, mnist_dir) unless File.exist?(url_to_file_name(url))
      end
    end

    # Load the training set and return [images, labels]: images is a Numo::UInt8 array of
    # shape [num_images, rows, cols], labels a 1-D Numo::UInt8 array.
    # Raises DNN_MNIST_LoadError if a file is missing or has an unexpected header.
    def self.load_train
      load_dataset(URL_TRAIN_IMAGES, URL_TRAIN_LABELS)
    end

    # Load the test set; same return value and errors as load_train.
    def self.load_test
      load_dataset(URL_TEST_IMAGES, URL_TEST_LABELS)
    end

    # Shared implementation of load_train / load_test (one code path, so error messages
    # always name the files actually being loaded).
    def self.load_dataset(images_url, labels_url)
      downloads
      images_file_name = url_to_file_name(images_url)
      labels_file_name = url_to_file_name(labels_url)
      [images_file_name, labels_file_name].each do |file_name|
        next if File.exist?(file_name)
        raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is not found.`)
      end
      [load_images(images_file_name), load_labels(labels_file_name)]
    end

    # Read a gzip-compressed image file. The header is read as big-endian uint32 pairs:
    # (magic, image count) then (rows, cols); the remaining bytes are the pixels.
    def self.load_images(file_name)
      images = nil
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_images = f.read(8).unpack("N2")
        check_magic(file_name, magic, IMAGES_MAGIC)
        rows, cols = f.read(8).unpack("N2")
        # Dimensions follow the header order: rows first, then cols.
        images = Numo::UInt8.from_binary(f.read).reshape(num_images, rows, cols)
      end
      images
    end

    # Read a gzip-compressed label file: (magic, label count) as big-endian uint32,
    # then one byte per label.
    def self.load_labels(file_name)
      labels = nil
      Zlib::GzipReader.open(file_name) do |f|
        magic, _num_labels = f.read(8).unpack("N2")
        check_magic(file_name, magic, LABELS_MAGIC)
        labels = Numo::UInt8.from_binary(f.read)
      end
      labels
    end

    # Raise DNN_MNIST_LoadError unless +magic+ equals +expected+.
    def self.check_magic(file_name, magic, expected)
      return if magic == expected
      raise DNN_MNIST_LoadError.new(%`file "#{file_name}" has invalid magic number #{magic.inspect} (expected #{expected}).`)
    end

    # Directory the MNIST files are downloaded into.
    def self.mnist_dir
      "#{__dir__}/downloads/mnist"
    end

    # Local path of the downloaded copy of +url+.
    def self.url_to_file_name(url)
      mnist_dir + "/" + url.match(%r`.+/(.+)$`)[1]
    end

    private_class_method :load_dataset, :load_images, :load_labels, :check_magic,
                         :mnist_dir, :url_to_file_name
  end
end