ruby-dnn 0.10.1 → 0.10.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar100_example.rb +71 -71
- data/examples/cifar10_example.rb +68 -68
- data/examples/iris_example.rb +34 -34
- data/examples/mnist_conv2d_example.rb +50 -50
- data/examples/mnist_example.rb +39 -39
- data/examples/mnist_lstm_example.rb +36 -36
- data/examples/xor_example.rb +24 -24
- data/lib/dnn.rb +27 -26
- data/lib/dnn/cifar10.rb +51 -51
- data/lib/dnn/cifar100.rb +49 -49
- data/lib/dnn/core/activations.rb +148 -148
- data/lib/dnn/core/cnn_layers.rb +464 -464
- data/lib/dnn/core/dataset.rb +34 -34
- data/lib/dnn/core/embedding.rb +56 -0
- data/lib/dnn/core/error.rb +5 -5
- data/lib/dnn/core/initializers.rb +126 -126
- data/lib/dnn/core/layers.rb +307 -307
- data/lib/dnn/core/losses.rb +175 -175
- data/lib/dnn/core/model.rb +461 -461
- data/lib/dnn/core/normalizations.rb +72 -72
- data/lib/dnn/core/optimizers.rb +283 -283
- data/lib/dnn/core/param.rb +9 -9
- data/lib/dnn/core/regularizers.rb +106 -106
- data/lib/dnn/core/rnn_layers.rb +464 -464
- data/lib/dnn/core/utils.rb +34 -34
- data/lib/dnn/downloader.rb +50 -50
- data/lib/dnn/image.rb +41 -41
- data/lib/dnn/iris.rb +60 -60
- data/lib/dnn/mnist.rb +84 -84
- data/lib/dnn/version.rb +3 -3
- metadata +2 -1
data/examples/mnist_example.rb
CHANGED
@@ -1,39 +1,39 @@
|
|
1
|
-
require "dnn"
|
2
|
-
require "dnn/mnist"
|
3
|
-
# If you use numo/linalg then please uncomment out.
|
4
|
-
# require "numo/linalg/autoloader"
|
5
|
-
|
6
|
-
include DNN::Layers
|
7
|
-
include DNN::Activations
|
8
|
-
include DNN::Optimizers
|
9
|
-
include DNN::Losses
|
10
|
-
Model = DNN::Model
|
11
|
-
MNIST = DNN::MNIST
|
12
|
-
|
13
|
-
x_train, y_train = MNIST.load_train
|
14
|
-
x_test, y_test = MNIST.load_test
|
15
|
-
|
16
|
-
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
|
17
|
-
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
|
18
|
-
|
19
|
-
x_train /= 255
|
20
|
-
x_test /= 255
|
21
|
-
|
22
|
-
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
23
|
-
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
24
|
-
|
25
|
-
model = Model.new
|
26
|
-
|
27
|
-
model << InputLayer.new(784)
|
28
|
-
|
29
|
-
model << Dense.new(256)
|
30
|
-
model << ReLU.new
|
31
|
-
|
32
|
-
model << Dense.new(256)
|
33
|
-
model << ReLU.new
|
34
|
-
|
35
|
-
model << Dense.new(10)
|
36
|
-
|
37
|
-
model.compile(RMSProp.new, SoftmaxCrossEntropy.new)
|
38
|
-
|
39
|
-
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/mnist"
|
3
|
+
# If you use numo/linalg then please uncomment out.
|
4
|
+
# require "numo/linalg/autoloader"
|
5
|
+
|
6
|
+
include DNN::Layers
|
7
|
+
include DNN::Activations
|
8
|
+
include DNN::Optimizers
|
9
|
+
include DNN::Losses
|
10
|
+
Model = DNN::Model
|
11
|
+
MNIST = DNN::MNIST
|
12
|
+
|
13
|
+
x_train, y_train = MNIST.load_train
|
14
|
+
x_test, y_test = MNIST.load_test
|
15
|
+
|
16
|
+
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
|
17
|
+
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
|
18
|
+
|
19
|
+
x_train /= 255
|
20
|
+
x_test /= 255
|
21
|
+
|
22
|
+
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
23
|
+
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
24
|
+
|
25
|
+
model = Model.new
|
26
|
+
|
27
|
+
model << InputLayer.new(784)
|
28
|
+
|
29
|
+
model << Dense.new(256)
|
30
|
+
model << ReLU.new
|
31
|
+
|
32
|
+
model << Dense.new(256)
|
33
|
+
model << ReLU.new
|
34
|
+
|
35
|
+
model << Dense.new(10)
|
36
|
+
|
37
|
+
model.compile(RMSProp.new, SoftmaxCrossEntropy.new)
|
38
|
+
|
39
|
+
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
@@ -1,36 +1,36 @@
|
|
1
|
-
require "dnn"
|
2
|
-
require "dnn/mnist"
|
3
|
-
# If you use numo/linalg then please uncomment out.
|
4
|
-
# require "numo/linalg/autoloader"
|
5
|
-
|
6
|
-
include DNN::Layers
|
7
|
-
include DNN::Activations
|
8
|
-
include DNN::Optimizers
|
9
|
-
include DNN::Losses
|
10
|
-
Model = DNN::Model
|
11
|
-
MNIST = DNN::MNIST
|
12
|
-
|
13
|
-
x_train, y_train = MNIST.load_train
|
14
|
-
x_test, y_test = MNIST.load_test
|
15
|
-
|
16
|
-
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 28, 28)
|
17
|
-
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 28, 28)
|
18
|
-
|
19
|
-
x_train /= 255
|
20
|
-
x_test /= 255
|
21
|
-
|
22
|
-
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
23
|
-
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
24
|
-
|
25
|
-
model = Model.new
|
26
|
-
|
27
|
-
model << InputLayer.new([28, 28])
|
28
|
-
|
29
|
-
model << LSTM.new(200)
|
30
|
-
model << LSTM.new(200, return_sequences: false)
|
31
|
-
|
32
|
-
model << Dense.new(10)
|
33
|
-
|
34
|
-
model.compile(Adam.new, SoftmaxCrossEntropy.new)
|
35
|
-
|
36
|
-
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/mnist"
|
3
|
+
# If you use numo/linalg then please uncomment out.
|
4
|
+
# require "numo/linalg/autoloader"
|
5
|
+
|
6
|
+
include DNN::Layers
|
7
|
+
include DNN::Activations
|
8
|
+
include DNN::Optimizers
|
9
|
+
include DNN::Losses
|
10
|
+
Model = DNN::Model
|
11
|
+
MNIST = DNN::MNIST
|
12
|
+
|
13
|
+
x_train, y_train = MNIST.load_train
|
14
|
+
x_test, y_test = MNIST.load_test
|
15
|
+
|
16
|
+
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 28, 28)
|
17
|
+
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 28, 28)
|
18
|
+
|
19
|
+
x_train /= 255
|
20
|
+
x_test /= 255
|
21
|
+
|
22
|
+
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
23
|
+
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
24
|
+
|
25
|
+
model = Model.new
|
26
|
+
|
27
|
+
model << InputLayer.new([28, 28])
|
28
|
+
|
29
|
+
model << LSTM.new(200)
|
30
|
+
model << LSTM.new(200, return_sequences: false)
|
31
|
+
|
32
|
+
model << Dense.new(10)
|
33
|
+
|
34
|
+
model.compile(Adam.new, SoftmaxCrossEntropy.new)
|
35
|
+
|
36
|
+
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
data/examples/xor_example.rb
CHANGED
@@ -1,24 +1,24 @@
|
|
1
|
-
require "dnn"
|
2
|
-
|
3
|
-
include DNN::Layers
|
4
|
-
include DNN::Activations
|
5
|
-
include DNN::Optimizers
|
6
|
-
include DNN::Losses
|
7
|
-
Model = DNN::Model
|
8
|
-
Utils = DNN::Utils
|
9
|
-
|
10
|
-
x = Numo::SFloat[[0, 0], [1, 0], [0, 1], [1, 1]]
|
11
|
-
y = Numo::SFloat[[0], [1], [1], [0]]
|
12
|
-
|
13
|
-
model = Model.new
|
14
|
-
|
15
|
-
model << InputLayer.new(2)
|
16
|
-
model << Dense.new(16)
|
17
|
-
model << ReLU.new
|
18
|
-
model << Dense.new(1)
|
19
|
-
|
20
|
-
model.compile(SGD.new, SigmoidCrossEntropy.new)
|
21
|
-
|
22
|
-
model.train(x, y, 20000, batch_size: 4, verbose: false)
|
23
|
-
|
24
|
-
p Utils.sigmoid(model.predict(x))
|
1
|
+
require "dnn"
|
2
|
+
|
3
|
+
include DNN::Layers
|
4
|
+
include DNN::Activations
|
5
|
+
include DNN::Optimizers
|
6
|
+
include DNN::Losses
|
7
|
+
Model = DNN::Model
|
8
|
+
Utils = DNN::Utils
|
9
|
+
|
10
|
+
x = Numo::SFloat[[0, 0], [1, 0], [0, 1], [1, 1]]
|
11
|
+
y = Numo::SFloat[[0], [1], [1], [0]]
|
12
|
+
|
13
|
+
model = Model.new
|
14
|
+
|
15
|
+
model << InputLayer.new(2)
|
16
|
+
model << Dense.new(16)
|
17
|
+
model << ReLU.new
|
18
|
+
model << Dense.new(1)
|
19
|
+
|
20
|
+
model.compile(SGD.new, SigmoidCrossEntropy.new)
|
21
|
+
|
22
|
+
model.train(x, y, 20000, batch_size: 4, verbose: false)
|
23
|
+
|
24
|
+
p Utils.sigmoid(model.predict(x))
|
data/lib/dnn.rb
CHANGED
@@ -1,26 +1,27 @@
|
|
1
|
-
if defined? Cumo
|
2
|
-
Xumo = Cumo
|
3
|
-
else
|
4
|
-
require "numo/narray"
|
5
|
-
Xumo = Numo
|
6
|
-
end
|
7
|
-
|
8
|
-
module DNN
|
9
|
-
NMath = Xumo::NMath
|
10
|
-
end
|
11
|
-
|
12
|
-
require_relative "dnn/version"
|
13
|
-
require_relative "dnn/core/error"
|
14
|
-
require_relative "dnn/core/model"
|
15
|
-
require_relative "dnn/core/param"
|
16
|
-
require_relative "dnn/core/dataset"
|
17
|
-
require_relative "dnn/core/initializers"
|
18
|
-
require_relative "dnn/core/layers"
|
19
|
-
require_relative "dnn/core/normalizations"
|
20
|
-
require_relative "dnn/core/activations"
|
21
|
-
require_relative "dnn/core/losses"
|
22
|
-
require_relative "dnn/core/regularizers"
|
23
|
-
require_relative "dnn/core/cnn_layers"
|
24
|
-
require_relative "dnn/core/
|
25
|
-
require_relative "dnn/core/
|
26
|
-
require_relative "dnn/core/
|
1
|
+
if defined? Cumo
|
2
|
+
Xumo = Cumo
|
3
|
+
else
|
4
|
+
require "numo/narray"
|
5
|
+
Xumo = Numo
|
6
|
+
end
|
7
|
+
|
8
|
+
module DNN
|
9
|
+
NMath = Xumo::NMath
|
10
|
+
end
|
11
|
+
|
12
|
+
require_relative "dnn/version"
|
13
|
+
require_relative "dnn/core/error"
|
14
|
+
require_relative "dnn/core/model"
|
15
|
+
require_relative "dnn/core/param"
|
16
|
+
require_relative "dnn/core/dataset"
|
17
|
+
require_relative "dnn/core/initializers"
|
18
|
+
require_relative "dnn/core/layers"
|
19
|
+
require_relative "dnn/core/normalizations"
|
20
|
+
require_relative "dnn/core/activations"
|
21
|
+
require_relative "dnn/core/losses"
|
22
|
+
require_relative "dnn/core/regularizers"
|
23
|
+
require_relative "dnn/core/cnn_layers"
|
24
|
+
require_relative "dnn/core/embedding"
|
25
|
+
require_relative "dnn/core/rnn_layers"
|
26
|
+
require_relative "dnn/core/optimizers"
|
27
|
+
require_relative "dnn/core/utils"
|
data/lib/dnn/cifar10.rb
CHANGED
@@ -1,51 +1,51 @@
|
|
1
|
-
require "zlib"
|
2
|
-
require "archive/tar/minitar"
|
3
|
-
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
-
require_relative "downloader"
|
5
|
-
|
6
|
-
URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
|
7
|
-
DIR_CIFAR10 = "cifar-10-batches-bin"
|
8
|
-
|
9
|
-
module DNN
|
10
|
-
module CIFAR10
|
11
|
-
class DNN_CIFAR10_LoadError < DNN_Error; end
|
12
|
-
|
13
|
-
def self.downloads
|
14
|
-
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
|
15
|
-
Downloader.download(URL_CIFAR10)
|
16
|
-
cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
17
|
-
begin
|
18
|
-
Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
|
19
|
-
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
20
|
-
end
|
21
|
-
ensure
|
22
|
-
File.unlink(cifar10_binary_file_name)
|
23
|
-
end
|
24
|
-
end
|
25
|
-
|
26
|
-
def self.load_train
|
27
|
-
downloads
|
28
|
-
bin = ""
|
29
|
-
(1..5).each do |i|
|
30
|
-
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
|
31
|
-
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
32
|
-
bin << File.binread(fname)
|
33
|
-
end
|
34
|
-
x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
|
35
|
-
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
36
|
-
y_train = Numo::UInt8.from_binary(y_bin)
|
37
|
-
[x_train, y_train]
|
38
|
-
end
|
39
|
-
|
40
|
-
def self.load_test
|
41
|
-
downloads
|
42
|
-
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
|
43
|
-
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
44
|
-
bin = File.binread(fname)
|
45
|
-
x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
|
46
|
-
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
47
|
-
y_test = Numo::UInt8.from_binary(y_bin)
|
48
|
-
[x_test, y_test]
|
49
|
-
end
|
50
|
-
end
|
51
|
-
end
|
1
|
+
require "zlib"
|
2
|
+
require "archive/tar/minitar"
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
+
require_relative "downloader"
|
5
|
+
|
6
|
+
URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
|
7
|
+
DIR_CIFAR10 = "cifar-10-batches-bin"
|
8
|
+
|
9
|
+
module DNN
|
10
|
+
module CIFAR10
|
11
|
+
class DNN_CIFAR10_LoadError < DNN_Error; end
|
12
|
+
|
13
|
+
def self.downloads
|
14
|
+
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
|
15
|
+
Downloader.download(URL_CIFAR10)
|
16
|
+
cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
17
|
+
begin
|
18
|
+
Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
|
19
|
+
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
20
|
+
end
|
21
|
+
ensure
|
22
|
+
File.unlink(cifar10_binary_file_name)
|
23
|
+
end
|
24
|
+
end
|
25
|
+
|
26
|
+
def self.load_train
|
27
|
+
downloads
|
28
|
+
bin = ""
|
29
|
+
(1..5).each do |i|
|
30
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
|
31
|
+
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
32
|
+
bin << File.binread(fname)
|
33
|
+
end
|
34
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
|
35
|
+
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
36
|
+
y_train = Numo::UInt8.from_binary(y_bin)
|
37
|
+
[x_train, y_train]
|
38
|
+
end
|
39
|
+
|
40
|
+
def self.load_test
|
41
|
+
downloads
|
42
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
|
43
|
+
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
44
|
+
bin = File.binread(fname)
|
45
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
|
46
|
+
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
47
|
+
y_test = Numo::UInt8.from_binary(y_bin)
|
48
|
+
[x_test, y_test]
|
49
|
+
end
|
50
|
+
end
|
51
|
+
end
|
data/lib/dnn/cifar100.rb
CHANGED
@@ -1,49 +1,49 @@
|
|
1
|
-
require "zlib"
|
2
|
-
require "archive/tar/minitar"
|
3
|
-
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
-
require_relative "downloader"
|
5
|
-
|
6
|
-
URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
|
7
|
-
DIR_CIFAR100 = "cifar-100-binary"
|
8
|
-
|
9
|
-
module DNN
|
10
|
-
module CIFAR100
|
11
|
-
class DNN_CIFAR100_LoadError < DNN_Error; end
|
12
|
-
|
13
|
-
def self.downloads
|
14
|
-
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR100)
|
15
|
-
Downloader.download(URL_CIFAR100)
|
16
|
-
cifar100_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR100.match(%r`.+/(.+)`)[1]
|
17
|
-
begin
|
18
|
-
Zlib::GzipReader.open(cifar100_binary_file_name) do |gz|
|
19
|
-
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
20
|
-
end
|
21
|
-
ensure
|
22
|
-
File.unlink(cifar100_binary_file_name)
|
23
|
-
end
|
24
|
-
end
|
25
|
-
|
26
|
-
def self.load_train
|
27
|
-
downloads
|
28
|
-
bin = ""
|
29
|
-
fname = __dir__ + "/downloads/#{DIR_CIFAR100}/train.bin"
|
30
|
-
raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
31
|
-
bin << File.binread(fname)
|
32
|
-
x_bin, y_bin = CIFAR100.load_binary(bin, 50000)
|
33
|
-
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
34
|
-
y_train = Numo::UInt8.from_binary(y_bin).reshape(50000, 2)
|
35
|
-
[x_train, y_train]
|
36
|
-
end
|
37
|
-
|
38
|
-
def self.load_test
|
39
|
-
downloads
|
40
|
-
fname = __dir__ + "/downloads/#{DIR_CIFAR100}/test.bin"
|
41
|
-
raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
42
|
-
bin = File.binread(fname)
|
43
|
-
x_bin, y_bin = CIFAR100.load_binary(bin, 10000)
|
44
|
-
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
45
|
-
y_test = Numo::UInt8.from_binary(y_bin).reshape(10000, 2)
|
46
|
-
[x_test, y_test]
|
47
|
-
end
|
48
|
-
end
|
49
|
-
end
|
1
|
+
require "zlib"
|
2
|
+
require "archive/tar/minitar"
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
+
require_relative "downloader"
|
5
|
+
|
6
|
+
URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
|
7
|
+
DIR_CIFAR100 = "cifar-100-binary"
|
8
|
+
|
9
|
+
module DNN
|
10
|
+
module CIFAR100
|
11
|
+
class DNN_CIFAR100_LoadError < DNN_Error; end
|
12
|
+
|
13
|
+
def self.downloads
|
14
|
+
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR100)
|
15
|
+
Downloader.download(URL_CIFAR100)
|
16
|
+
cifar100_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR100.match(%r`.+/(.+)`)[1]
|
17
|
+
begin
|
18
|
+
Zlib::GzipReader.open(cifar100_binary_file_name) do |gz|
|
19
|
+
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
20
|
+
end
|
21
|
+
ensure
|
22
|
+
File.unlink(cifar100_binary_file_name)
|
23
|
+
end
|
24
|
+
end
|
25
|
+
|
26
|
+
def self.load_train
|
27
|
+
downloads
|
28
|
+
bin = ""
|
29
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR100}/train.bin"
|
30
|
+
raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
31
|
+
bin << File.binread(fname)
|
32
|
+
x_bin, y_bin = CIFAR100.load_binary(bin, 50000)
|
33
|
+
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
34
|
+
y_train = Numo::UInt8.from_binary(y_bin).reshape(50000, 2)
|
35
|
+
[x_train, y_train]
|
36
|
+
end
|
37
|
+
|
38
|
+
def self.load_test
|
39
|
+
downloads
|
40
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR100}/test.bin"
|
41
|
+
raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
42
|
+
bin = File.binread(fname)
|
43
|
+
x_bin, y_bin = CIFAR100.load_binary(bin, 10000)
|
44
|
+
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
45
|
+
y_test = Numo::UInt8.from_binary(y_bin).reshape(10000, 2)
|
46
|
+
[x_test, y_test]
|
47
|
+
end
|
48
|
+
end
|
49
|
+
end
|