ruby-dnn 0.10.1 → 0.10.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,34 +1,34 @@
1
- module DNN
2
- # This module provides utility functions.
3
- module Utils
4
- # Categorize labels into "num_classes" classes.
5
- def self.to_categorical(y, num_classes, narray_type = nil)
6
- narray_type ||= y.class
7
- y2 = narray_type.zeros(y.shape[0], num_classes)
8
- y.shape[0].times do |i|
9
- y2[i, y[i]] = 1
10
- end
11
- y2
12
- end
13
-
14
- # Convert hash to an object.
15
- def self.from_hash(hash)
16
- return nil if hash == nil
17
- dnn_class = DNN.const_get(hash[:class])
18
- if dnn_class.respond_to?(:from_hash)
19
- return dnn_class.from_hash(hash)
20
- end
21
- dnn_class.new
22
- end
23
-
24
- # Return the result of the sigmoid function.
25
- def self.sigmoid(x)
26
- Sigmoid.new.forward(x)
27
- end
28
-
29
- # Return the result of the softmax function.
30
- def self.softmax(x)
31
- SoftmaxCrossEntropy.softmax(x)
32
- end
33
- end
34
- end
1
+ module DNN
2
+ # This module provides utility functions.
3
+ module Utils
4
+ # Categorize labels into "num_classes" classes.
5
+ def self.to_categorical(y, num_classes, narray_type = nil)
6
+ narray_type ||= y.class
7
+ y2 = narray_type.zeros(y.shape[0], num_classes)
8
+ y.shape[0].times do |i|
9
+ y2[i, y[i]] = 1
10
+ end
11
+ y2
12
+ end
13
+
14
+ # Convert hash to an object.
15
+ def self.from_hash(hash)
16
+ return nil if hash == nil
17
+ dnn_class = DNN.const_get(hash[:class])
18
+ if dnn_class.respond_to?(:from_hash)
19
+ return dnn_class.from_hash(hash)
20
+ end
21
+ dnn_class.new
22
+ end
23
+
24
+ # Return the result of the sigmoid function.
25
+ def self.sigmoid(x)
26
+ Sigmoid.new.forward(x)
27
+ end
28
+
29
+ # Return the result of the softmax function.
30
+ def self.softmax(x)
31
+ SoftmaxCrossEntropy.softmax(x)
32
+ end
33
+ end
34
+ end
@@ -1,50 +1,50 @@
1
- require "net/http"
2
-
3
- module DNN
4
-
5
- class DNN_DownloadError < DNN_Error; end
6
-
7
- class Downloader
8
- def self.download(url, dir_path = nil)
9
- unless dir_path
10
- Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
11
- dir_path = "#{__dir__}/downloads"
12
- end
13
- Downloader.new(url).download(dir_path)
14
- rescue => ex
15
- raise DNN_DownloadError.new(ex.message)
16
- end
17
-
18
- def initialize(url)
19
- @url = url
20
- *, @fqdn, @path = *url.match(%r`https?://(.+?)(/.+)`)
21
- end
22
-
23
- def download(dir_path)
24
- puts %`download "#{@url}"`
25
- buf = ""
26
- Net::HTTP.start(@fqdn) do |http|
27
- content_length = http.head(@path).content_length
28
- http.get(@path) do |body_segment|
29
- buf << body_segment
30
- log = "\r"
31
- 40.times do |i|
32
- if i < buf.size * 40 / content_length
33
- log << "="
34
- elsif i == buf.size * 40 / content_length
35
- log << ">"
36
- else
37
- log << "_"
38
- end
39
- end
40
- log << " #{buf.size}/#{content_length}"
41
- print log
42
- end
43
- puts ""
44
- end
45
- file_name = @path.match(%r`.+/(.+)`)[1]
46
- File.binwrite("#{dir_path}/#{file_name}", buf)
47
- end
48
- end
49
-
50
- end
1
+ require "net/http"
2
+
3
+ module DNN
4
+
5
+ class DNN_DownloadError < DNN_Error; end
6
+
7
+ class Downloader
8
+ def self.download(url, dir_path = nil)
9
+ unless dir_path
10
+ Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
11
+ dir_path = "#{__dir__}/downloads"
12
+ end
13
+ Downloader.new(url).download(dir_path)
14
+ rescue => ex
15
+ raise DNN_DownloadError.new(ex.message)
16
+ end
17
+
18
+ def initialize(url)
19
+ @url = url
20
+ *, @fqdn, @path = *url.match(%r`https?://(.+?)(/.+)`)
21
+ end
22
+
23
+ def download(dir_path)
24
+ puts %`download "#{@url}"`
25
+ buf = ""
26
+ Net::HTTP.start(@fqdn) do |http|
27
+ content_length = http.head(@path).content_length
28
+ http.get(@path) do |body_segment|
29
+ buf << body_segment
30
+ log = "\r"
31
+ 40.times do |i|
32
+ if i < buf.size * 40 / content_length
33
+ log << "="
34
+ elsif i == buf.size * 40 / content_length
35
+ log << ">"
36
+ else
37
+ log << "_"
38
+ end
39
+ end
40
+ log << " #{buf.size}/#{content_length}"
41
+ print log
42
+ end
43
+ puts ""
44
+ end
45
+ file_name = @path.match(%r`.+/(.+)`)[1]
46
+ File.binwrite("#{dir_path}/#{file_name}", buf)
47
+ end
48
+ end
49
+
50
+ end
@@ -1,41 +1,41 @@
1
- require "numo/narray"
2
- require_relative "../../ext/rb_stb_image/rb_stb_image"
3
-
4
- module DNN
5
- module Image
6
- def self.read(file_name)
7
- raise Image::ReadError.new("#{file_name} is not found.") unless File.exist?(file_name)
8
- bin, w, h, n = Stb.stbi_load(file_name, 3)
9
- img = Numo::UInt8.from_binary(bin)
10
- img.reshape(h, w, 3)
11
- end
12
-
13
- def self.write(file_name, img, quality: 100)
14
- if img.shape.length == 2
15
- img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
16
- elsif img.shape[2] == 1
17
- img = img.reshape(img.shape[0], img.shape[1])
18
- img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
19
- end
20
- h, w, ch = img.shape
21
- bin = img.to_binary
22
- case file_name
23
- when /\.png$/i
24
- stride_in_bytes = w * ch
25
- Stb.stbi_write_png(file_name, w, h, ch, bin, stride_in_bytes)
26
- when /\.bmp$/i
27
- Stb.stbi_write_bmp(file_name, w, h, ch, bin)
28
- when /\.jpg$/i, /\.jpeg/i
29
- Stb.stbi_write_jpg(file_name, w, h, ch, bin, quality)
30
- end
31
- rescue => ex
32
- raise Image::WriteError.new(ex.message)
33
- end
34
- end
35
-
36
- class Image::Error < StandardError; end
37
-
38
- class Image::ReadError < Image::Error; end
39
-
40
- class Image::WriteError < Image::Error; end
41
- end
1
+ require "numo/narray"
2
+ require_relative "../../ext/rb_stb_image/rb_stb_image"
3
+
4
+ module DNN
5
+ module Image
6
+ def self.read(file_name)
7
+ raise Image::ReadError.new("#{file_name} is not found.") unless File.exist?(file_name)
8
+ bin, w, h, n = Stb.stbi_load(file_name, 3)
9
+ img = Numo::UInt8.from_binary(bin)
10
+ img.reshape(h, w, 3)
11
+ end
12
+
13
+ def self.write(file_name, img, quality: 100)
14
+ if img.shape.length == 2
15
+ img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
16
+ elsif img.shape[2] == 1
17
+ img = img.reshape(img.shape[0], img.shape[1])
18
+ img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
19
+ end
20
+ h, w, ch = img.shape
21
+ bin = img.to_binary
22
+ case file_name
23
+ when /\.png$/i
24
+ stride_in_bytes = w * ch
25
+ Stb.stbi_write_png(file_name, w, h, ch, bin, stride_in_bytes)
26
+ when /\.bmp$/i
27
+ Stb.stbi_write_bmp(file_name, w, h, ch, bin)
28
+ when /\.jpg$/i, /\.jpeg/i
29
+ Stb.stbi_write_jpg(file_name, w, h, ch, bin, quality)
30
+ end
31
+ rescue => ex
32
+ raise Image::WriteError.new(ex.message)
33
+ end
34
+ end
35
+
36
+ class Image::Error < StandardError; end
37
+
38
+ class Image::ReadError < Image::Error; end
39
+
40
+ class Image::WriteError < Image::Error; end
41
+ end
@@ -1,60 +1,60 @@
1
- require "csv"
2
- require_relative "downloader"
3
-
4
- module DNN
5
- class DNN_Iris_LoadError < DNN_Error; end
6
-
7
- module Iris
8
- URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
9
-
10
- # Iris-setosa
11
- SETOSA = 0
12
- # Iris-versicolor
13
- VERSICOLOR = 1
14
- # Iris-virginica
15
- VIRGINICA = 2
16
-
17
- def self.downloads
18
- return if File.exist?(url_to_file_name(URL_CSV))
19
- Downloader.download(URL_CSV)
20
- end
21
-
22
- def self.load(shuffle = false, shuffle_seed = rand(1 << 31))
23
- downloads
24
- csv_array = CSV.read(url_to_file_name(URL_CSV)).select { |a| a.length > 0 }
25
- x = Numo::SFloat.zeros(csv_array.length, 4)
26
- y = Numo::SFloat.zeros(csv_array.length)
27
- csv_array.each.with_index do |(sepal_length, sepal_width, petal_length, petal_width, classes), i|
28
- x[i, 0] = sepal_length.to_f
29
- x[i, 1] = sepal_width.to_f
30
- x[i, 2] = petal_length.to_f
31
- x[i, 3] = petal_width.to_f
32
- y[i] = case classes
33
- when "Iris-setosa"
34
- SETOSA
35
- when "Iris-versicolor"
36
- VERSICOLOR
37
- when "Iris-virginica"
38
- VIRGINICA
39
- else
40
- raise DNN_Iris_LoadError.new("Unknown class name '#{classes}' for iris")
41
- end
42
- end
43
- if shuffle
44
- orig_seed = Random::DEFAULT.seed
45
- srand(shuffle_seed)
46
- indexs = (0...csv_array.length).to_a.shuffle
47
- x[indexs, true] = x
48
- y[indexs] = y
49
- srand(orig_seed)
50
- end
51
- [x, y]
52
- end
53
-
54
- private_class_method
55
-
56
- def self.url_to_file_name(url)
57
- __dir__ + "/downloads/" + url.match(%r`.+/(.+)$`)[1]
58
- end
59
- end
60
- end
1
+ require "csv"
2
+ require_relative "downloader"
3
+
4
+ module DNN
5
+ class DNN_Iris_LoadError < DNN_Error; end
6
+
7
+ module Iris
8
+ URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
9
+
10
+ # Iris-setosa
11
+ SETOSA = 0
12
+ # Iris-versicolor
13
+ VERSICOLOR = 1
14
+ # Iris-virginica
15
+ VIRGINICA = 2
16
+
17
+ def self.downloads
18
+ return if File.exist?(url_to_file_name(URL_CSV))
19
+ Downloader.download(URL_CSV)
20
+ end
21
+
22
+ def self.load(shuffle = false, shuffle_seed = rand(1 << 31))
23
+ downloads
24
+ csv_array = CSV.read(url_to_file_name(URL_CSV)).select { |a| a.length > 0 }
25
+ x = Numo::SFloat.zeros(csv_array.length, 4)
26
+ y = Numo::SFloat.zeros(csv_array.length)
27
+ csv_array.each.with_index do |(sepal_length, sepal_width, petal_length, petal_width, classes), i|
28
+ x[i, 0] = sepal_length.to_f
29
+ x[i, 1] = sepal_width.to_f
30
+ x[i, 2] = petal_length.to_f
31
+ x[i, 3] = petal_width.to_f
32
+ y[i] = case classes
33
+ when "Iris-setosa"
34
+ SETOSA
35
+ when "Iris-versicolor"
36
+ VERSICOLOR
37
+ when "Iris-virginica"
38
+ VIRGINICA
39
+ else
40
+ raise DNN_Iris_LoadError.new("Unknown class name '#{classes}' for iris")
41
+ end
42
+ end
43
+ if shuffle
44
+ orig_seed = Random::DEFAULT.seed
45
+ srand(shuffle_seed)
46
+ indexs = (0...csv_array.length).to_a.shuffle
47
+ x[indexs, true] = x
48
+ y[indexs] = y
49
+ srand(orig_seed)
50
+ end
51
+ [x, y]
52
+ end
53
+
54
+ private_class_method
55
+
56
+ def self.url_to_file_name(url)
57
+ __dir__ + "/downloads/" + url.match(%r`.+/(.+)$`)[1]
58
+ end
59
+ end
60
+ end
@@ -1,84 +1,84 @@
1
- require "zlib"
2
- require_relative "core/error"
3
- require_relative "downloader"
4
-
5
- module DNN
6
- module MNIST
7
- class DNN_MNIST_LoadError < DNN_Error; end
8
-
9
- URL_TRAIN_IMAGES = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
10
- URL_TRAIN_LABELS = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
11
- URL_TEST_IMAGES = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
12
- URL_TEST_LABELS = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
13
-
14
- def self.downloads
15
- return if Dir.exist?(mnist_dir)
16
- Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
17
- Dir.mkdir(mnist_dir)
18
- Downloader.download(URL_TRAIN_IMAGES, mnist_dir)
19
- Downloader.download(URL_TRAIN_LABELS, mnist_dir)
20
- Downloader.download(URL_TEST_IMAGES, mnist_dir)
21
- Downloader.download(URL_TEST_LABELS, mnist_dir)
22
- end
23
-
24
- def self.load_train
25
- downloads
26
- train_images_file_name = url_to_file_name(URL_TRAIN_IMAGES)
27
- train_labels_file_name = url_to_file_name(URL_TRAIN_LABELS)
28
- unless File.exist?(train_images_file_name)
29
- raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
30
- end
31
- unless File.exist?(train_labels_file_name)
32
- raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
33
- end
34
- images = load_images(train_images_file_name)
35
- labels = load_labels(train_labels_file_name)
36
- [images, labels]
37
- end
38
-
39
- def self.load_test
40
- downloads
41
- test_images_file_name = url_to_file_name(URL_TEST_IMAGES)
42
- test_labels_file_name = url_to_file_name(URL_TEST_LABELS)
43
- unless File.exist?(test_images_file_name)
44
- raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
45
- end
46
- unless File.exist?(test_labels_file_name)
47
- raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
48
- end
49
- images = load_images(test_images_file_name)
50
- labels = load_labels(test_labels_file_name)
51
- [images, labels]
52
- end
53
-
54
- private_class_method
55
-
56
- def self.load_images(file_name)
57
- images = nil
58
- Zlib::GzipReader.open(file_name) do |f|
59
- magic, num_images = f.read(8).unpack("N2")
60
- rows, cols = f.read(8).unpack("N2")
61
- images = Numo::UInt8.from_binary(f.read)
62
- images = images.reshape(num_images, cols, rows)
63
- end
64
- images
65
- end
66
-
67
- def self.load_labels(file_name)
68
- labels = nil
69
- Zlib::GzipReader.open(file_name) do |f|
70
- magic, num_labels = f.read(8).unpack("N2")
71
- labels = Numo::UInt8.from_binary(f.read)
72
- end
73
- labels
74
- end
75
-
76
- def self.mnist_dir
77
- "#{__dir__}/downloads/mnist"
78
- end
79
-
80
- def self.url_to_file_name(url)
81
- mnist_dir + "/" + url.match(%r`.+/(.+)$`)[1]
82
- end
83
- end
84
- end
1
+ require "zlib"
2
+ require_relative "core/error"
3
+ require_relative "downloader"
4
+
5
+ module DNN
6
+ module MNIST
7
+ class DNN_MNIST_LoadError < DNN_Error; end
8
+
9
+ URL_TRAIN_IMAGES = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
10
+ URL_TRAIN_LABELS = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
11
+ URL_TEST_IMAGES = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
12
+ URL_TEST_LABELS = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
13
+
14
+ def self.downloads
15
+ return if Dir.exist?(mnist_dir)
16
+ Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
17
+ Dir.mkdir(mnist_dir)
18
+ Downloader.download(URL_TRAIN_IMAGES, mnist_dir)
19
+ Downloader.download(URL_TRAIN_LABELS, mnist_dir)
20
+ Downloader.download(URL_TEST_IMAGES, mnist_dir)
21
+ Downloader.download(URL_TEST_LABELS, mnist_dir)
22
+ end
23
+
24
+ def self.load_train
25
+ downloads
26
+ train_images_file_name = url_to_file_name(URL_TRAIN_IMAGES)
27
+ train_labels_file_name = url_to_file_name(URL_TRAIN_LABELS)
28
+ unless File.exist?(train_images_file_name)
29
+ raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
30
+ end
31
+ unless File.exist?(train_labels_file_name)
32
+ raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
33
+ end
34
+ images = load_images(train_images_file_name)
35
+ labels = load_labels(train_labels_file_name)
36
+ [images, labels]
37
+ end
38
+
39
+ def self.load_test
40
+ downloads
41
+ test_images_file_name = url_to_file_name(URL_TEST_IMAGES)
42
+ test_labels_file_name = url_to_file_name(URL_TEST_LABELS)
43
+ unless File.exist?(test_images_file_name)
44
+ raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
45
+ end
46
+ unless File.exist?(test_labels_file_name)
47
+ raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
48
+ end
49
+ images = load_images(test_images_file_name)
50
+ labels = load_labels(test_labels_file_name)
51
+ [images, labels]
52
+ end
53
+
54
+ private_class_method
55
+
56
+ def self.load_images(file_name)
57
+ images = nil
58
+ Zlib::GzipReader.open(file_name) do |f|
59
+ magic, num_images = f.read(8).unpack("N2")
60
+ rows, cols = f.read(8).unpack("N2")
61
+ images = Numo::UInt8.from_binary(f.read)
62
+ images = images.reshape(num_images, cols, rows)
63
+ end
64
+ images
65
+ end
66
+
67
+ def self.load_labels(file_name)
68
+ labels = nil
69
+ Zlib::GzipReader.open(file_name) do |f|
70
+ magic, num_labels = f.read(8).unpack("N2")
71
+ labels = Numo::UInt8.from_binary(f.read)
72
+ end
73
+ labels
74
+ end
75
+
76
+ def self.mnist_dir
77
+ "#{__dir__}/downloads/mnist"
78
+ end
79
+
80
+ def self.url_to_file_name(url)
81
+ mnist_dir + "/" + url.match(%r`.+/(.+)$`)[1]
82
+ end
83
+ end
84
+ end