ruby-dnn 0.9.4 → 0.10.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +39 -3
- data/Rakefile +6 -0
- data/examples/cifar100_example.rb +71 -0
- data/examples/cifar10_example.rb +2 -1
- data/examples/iris_example.rb +2 -1
- data/examples/mnist_conv2d_example.rb +2 -1
- data/examples/mnist_example.rb +2 -3
- data/examples/mnist_lstm_example.rb +2 -1
- data/ext/cifar_loader/cifar_loader.c +77 -0
- data/ext/cifar_loader/extconf.rb +3 -0
- data/lib/dnn.rb +1 -0
- data/lib/dnn/{lib/cifar10.rb → cifar10.rb} +9 -11
- data/lib/dnn/cifar100.rb +49 -0
- data/lib/dnn/core/activations.rb +28 -24
- data/lib/dnn/core/cnn_layers.rb +216 -94
- data/lib/dnn/core/dataset.rb +21 -5
- data/lib/dnn/core/initializers.rb +3 -3
- data/lib/dnn/core/layers.rb +81 -150
- data/lib/dnn/core/losses.rb +88 -49
- data/lib/dnn/core/model.rb +97 -74
- data/lib/dnn/core/normalizations.rb +72 -0
- data/lib/dnn/core/optimizers.rb +171 -78
- data/lib/dnn/core/regularizers.rb +92 -22
- data/lib/dnn/core/rnn_layers.rb +146 -121
- data/lib/dnn/core/utils.rb +4 -3
- data/lib/dnn/{lib/downloader.rb → downloader.rb} +5 -1
- data/lib/dnn/{lib/image.rb → image.rb} +1 -1
- data/lib/dnn/{lib/iris.rb → iris.rb} +1 -1
- data/lib/dnn/{lib/mnist.rb → mnist.rb} +4 -3
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +1 -1
- metadata +13 -12
- data/API-Reference.ja.md +0 -978
- data/LIB-API-Reference.ja.md +0 -97
- data/ext/cifar10_loader/cifar10_loader.c +0 -44
- data/ext/cifar10_loader/extconf.rb +0 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9fbaca2827a3877b9c215d3a93b7c0e20a4a456471541e3e3f0140a95abf695b
|
4
|
+
data.tar.gz: 0e57474b603a0ba4d08d66176fea34b77e9dbc0c51639635bb861779a5bd047e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e02bcedc98d20dda9c4aae46955af777586cb433ca1d4f479248e56b2ec08a38fbb1b1023c644620664c40fda420a4584d98c1dd0080ab29212c0a0176b669d7
|
7
|
+
data.tar.gz: bb1486c4a85edbb5b5aed8e8feaffd15937893883d5d31b831427e7bcf663e4157452dd565f0bf4909b444e1241213f8f0a682f56d03d267954e66b9d312e7ab
|
data/README.md
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
# ruby-dnn
|
2
|
+
[![Gem Version](https://badge.fury.io/rb/ruby-dnn.svg)](https://badge.fury.io/rb/ruby-dnn)
|
2
3
|
|
3
4
|
ruby-dnn is a Ruby deep learning library. It supports fully connected neural networks and convolutional neural networks.
|
4
5
|
Currently, you can get 99% accuracy with MNIST and 74% with CIFAR 10.
|
@@ -21,10 +22,45 @@ Or install it yourself as:
|
|
21
22
|
|
22
23
|
## Usage
|
23
24
|
|
24
|
-
|
25
|
-
|
25
|
+
### MNIST MLP example
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
model = Model.new
|
29
|
+
|
30
|
+
model << InputLayer.new(784)
|
31
|
+
|
32
|
+
model << Dense.new(256)
|
33
|
+
model << ReLU.new
|
34
|
+
|
35
|
+
model << Dense.new(256)
|
36
|
+
model << ReLU.new
|
37
|
+
|
38
|
+
model << Dense.new(10)
|
39
|
+
|
40
|
+
model.compile(RMSProp.new, SoftmaxCrossEntropy.new)
|
41
|
+
|
42
|
+
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
43
|
+
|
44
|
+
```
|
45
|
+
|
46
|
+
Please refer to examples for basic usage.
|
26
47
|
If you want to know more detailed information, please refer to the source code.
|
27
48
|
|
49
|
+
## Implemented
|
50
|
+
|| Implemented classes |
|
51
|
+
|:-----------|------------:|
|
52
|
+
| Connections | Dense, Conv2D, Conv2D_Transpose, SimpleRNN, LSTM, GRU |
|
53
|
+
| Layers | Flatten, Reshape, Dropout, BatchNormalization, MaxPool2D, AvgPool2D, UnPool2D |
|
54
|
+
| Activations | Sigmoid, Tanh, Softsign, Softplus, Swish, ReLU, LeakyReLU, ELU |
|
55
|
+
| Optimizers | SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, Adam, RMSPropGraves |
|
56
|
+
| Losses | MeanSquaredError, MeanAbsoluteError, HuberLoss, SoftmaxCrossEntropy, SigmoidCrossEntropy |
|
57
|
+
|
58
|
+
## TODO
|
59
|
+
● Add CI badge.
|
60
|
+
● Write a test.
|
61
|
+
● Write a document.
|
62
|
+
● Support GPU.
|
63
|
+
|
28
64
|
## Development
|
29
65
|
|
30
66
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake "spec"` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
@@ -41,4 +77,4 @@ The gem is available as open source under the terms of the [MIT License](https:/
|
|
41
77
|
|
42
78
|
## Code of Conduct
|
43
79
|
|
44
|
-
Everyone interacting in the
|
80
|
+
Everyone interacting in the ruby-dnn project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/[USERNAME]/dnn/blob/master/CODE_OF_CONDUCT.md).
|
data/Rakefile
CHANGED
@@ -0,0 +1,71 @@
|
|
1
|
+
require "dnn"
require "dnn/cifar100"
# Uncomment the next line if you use numo/linalg.
# require "numo/linalg/autoloader"

include DNN::Layers
include DNN::Activations
include DNN::Optimizers
include DNN::Losses
Model = DNN::Model
CIFAR100 = DNN::CIFAR100

# Load the train/test splits, convert the images to SFloat and divide by 255.
x_train, y_train = CIFAR100.load_train
x_test, y_test = CIFAR100.load_test

x_train = Numo::SFloat.cast(x_train) / 255
x_test = Numo::SFloat.cast(x_test) / 255

# Keep label column 1 only, then one-hot encode it over 100 classes.
y_train = DNN::Utils.to_categorical(y_train[true, 1], 100, Numo::SFloat)
y_test = DNN::Utils.to_categorical(y_test[true, 1], 100, Numo::SFloat)

model = Model.new

model << InputLayer.new([32, 32, 3])

# Three stages (16, 32 and 64 filters), each made of two
# Conv2D -> BatchNormalization -> ReLU stacks. MaxPool2D.new(2) follows
# the first two stages only.
[16, 32, 64].each_with_index do |filters, stage|
  2.times do
    model << Conv2D.new(filters, 5, padding: true)
    model << BatchNormalization.new
    model << ReLU.new
  end
  model << MaxPool2D.new(2) if stage < 2
end

model << Flatten.new

# Classifier head: one hidden Dense layer with dropout, then 100 outputs.
model << Dense.new(1024)
model << BatchNormalization.new
model << ReLU.new
model << Dropout.new(0.5)

model << Dense.new(100)

model.compile(Adam.new, SoftmaxCrossEntropy.new)

model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
data/examples/cifar10_example.rb
CHANGED
data/examples/iris_example.rb
CHANGED
data/examples/mnist_example.rb
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
require "dnn"
|
2
|
-
require "dnn/
|
2
|
+
require "dnn/mnist"
|
3
|
+
# If you use numo/linalg, uncomment the following line.
|
3
4
|
# require "numo/linalg/autoloader"
|
4
5
|
|
5
6
|
include DNN::Layers
|
@@ -26,11 +27,9 @@ model = Model.new
|
|
26
27
|
model << InputLayer.new(784)
|
27
28
|
|
28
29
|
model << Dense.new(256)
|
29
|
-
model << BatchNormalization.new
|
30
30
|
model << ReLU.new
|
31
31
|
|
32
32
|
model << Dense.new(256)
|
33
|
-
model << BatchNormalization.new
|
34
33
|
model << ReLU.new
|
35
34
|
|
36
35
|
model << Dense.new(10)
|
@@ -0,0 +1,77 @@
|
|
1
|
+
#include <ruby.h>
|
2
|
+
#include <stdint.h>
|
3
|
+
#include <stdlib.h>
|
4
|
+
|
5
|
+
#define CIFAR_WIDTH 32
|
6
|
+
#define CIFAR_HEIGHT 32
|
7
|
+
#define CIFAR_CHANNEL 3
|
8
|
+
|
9
|
+
/*
 * Splits a raw CIFAR-10 binary dump into image bytes and label bytes.
 *
 * Each input record is read as 1 label byte followed by
 * CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL image bytes.
 *
 * rb_bin       - String holding at least rb_num_datas consecutive records.
 * rb_num_datas - Integer number of records to read.
 *
 * Returns a two-element Array [x_bin, y_bin]: x_bin is a String with the
 * image bytes of all records concatenated, y_bin a String with one label
 * byte per record.
 *
 * Raises TypeError if rb_bin is not String-like or rb_num_datas is not an
 * Integer, and ArgumentError if rb_num_datas is negative or rb_bin is too
 * short to contain that many records.
 */
static VALUE cifar10_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
    const long size = CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL;
    const long record_size = 1 + size;
    long num_datas = NUM2LONG(rb_num_datas);
    const uint8_t* bin;
    uint8_t* x_bin;
    uint8_t* y_bin;
    VALUE rb_x_bin;
    VALUE rb_y_bin;
    long i;

    StringValue(rb_bin);
    if (num_datas < 0) {
        rb_raise(rb_eArgError, "num_datas must not be negative (given %ld)", num_datas);
    }
    /* Dividing instead of multiplying keeps the bounds check overflow-free. */
    if (RSTRING_LEN(rb_bin) / record_size < num_datas) {
        rb_raise(rb_eArgError, "binary data is too short for %ld CIFAR-10 records", num_datas);
    }

    /* Write straight into Ruby-owned buffers: no malloc/free pair, hence
       nothing to leak if an allocation raises. */
    rb_x_bin = rb_str_new(NULL, num_datas * size);
    rb_y_bin = rb_str_new(NULL, num_datas);

    /* Fetch raw pointers only after every allocation has been done. */
    bin = (const uint8_t*)RSTRING_PTR(rb_bin);
    x_bin = (uint8_t*)RSTRING_PTR(rb_x_bin);
    y_bin = (uint8_t*)RSTRING_PTR(rb_y_bin);

    for (i = 0; i < num_datas; i++) {
        const uint8_t* record = bin + i * record_size;
        y_bin[i] = record[0];
        memcpy(x_bin + i * size, record + 1, (size_t)size);
    }
    return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}
|
38
|
+
|
39
|
+
/*
 * Splits a raw CIFAR-100 binary dump into image bytes and label bytes.
 *
 * Each input record is read as 2 label bytes followed by
 * CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL image bytes.
 *
 * rb_bin       - String holding at least rb_num_datas consecutive records.
 * rb_num_datas - Integer number of records to read.
 *
 * Returns a two-element Array [x_bin, y_bin]: x_bin is a String with the
 * image bytes of all records concatenated, y_bin a String with the two
 * label bytes of every record, in input order.
 *
 * Raises TypeError if rb_bin is not String-like or rb_num_datas is not an
 * Integer, and ArgumentError if rb_num_datas is negative or rb_bin is too
 * short to contain that many records.
 */
static VALUE cifar100_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
    const long size = CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL;
    const long label_size = 2;
    const long record_size = label_size + size;
    long num_datas = NUM2LONG(rb_num_datas);
    const uint8_t* bin;
    uint8_t* x_bin;
    uint8_t* y_bin;
    VALUE rb_x_bin;
    VALUE rb_y_bin;
    long i;

    StringValue(rb_bin);
    if (num_datas < 0) {
        rb_raise(rb_eArgError, "num_datas must not be negative (given %ld)", num_datas);
    }
    /* Dividing instead of multiplying keeps the bounds check overflow-free. */
    if (RSTRING_LEN(rb_bin) / record_size < num_datas) {
        rb_raise(rb_eArgError, "binary data is too short for %ld CIFAR-100 records", num_datas);
    }

    /* Write straight into Ruby-owned buffers: no malloc/free pair, hence
       nothing to leak if an allocation raises. */
    rb_x_bin = rb_str_new(NULL, num_datas * size);
    rb_y_bin = rb_str_new(NULL, num_datas * label_size);

    /* Fetch raw pointers only after every allocation has been done. */
    bin = (const uint8_t*)RSTRING_PTR(rb_bin);
    x_bin = (uint8_t*)RSTRING_PTR(rb_x_bin);
    y_bin = (uint8_t*)RSTRING_PTR(rb_y_bin);

    for (i = 0; i < num_datas; i++) {
        const uint8_t* record = bin + i * record_size;
        memcpy(y_bin + i * label_size, record, (size_t)label_size);
        memcpy(x_bin + i * size, record + label_size, (size_t)size);
    }
    return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}
|
69
|
+
|
70
|
+
void Init_cifar_loader() {
|
71
|
+
VALUE rb_dnn = rb_define_module("DNN");
|
72
|
+
VALUE rb_cifar10 = rb_define_module_under(rb_dnn, "CIFAR10");
|
73
|
+
VALUE rb_cifar100 = rb_define_module_under(rb_dnn, "CIFAR100");
|
74
|
+
|
75
|
+
rb_define_singleton_method(rb_cifar10, "load_binary", cifar10_load_binary, 2);
|
76
|
+
rb_define_singleton_method(rb_cifar100, "load_binary", cifar100_load_binary, 2);
|
77
|
+
}
|
data/lib/dnn.rb
CHANGED
@@ -16,6 +16,7 @@ require_relative "dnn/core/param"
|
|
16
16
|
require_relative "dnn/core/dataset"
|
17
17
|
require_relative "dnn/core/initializers"
|
18
18
|
require_relative "dnn/core/layers"
|
19
|
+
require_relative "dnn/core/normalizations"
|
19
20
|
require_relative "dnn/core/activations"
|
20
21
|
require_relative "dnn/core/losses"
|
21
22
|
require_relative "dnn/core/regularizers"
|
@@ -1,24 +1,22 @@
|
|
1
1
|
require "zlib"
|
2
2
|
require "archive/tar/minitar"
|
3
|
-
require_relative "
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
4
|
require_relative "downloader"
|
5
5
|
|
6
6
|
URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
|
7
|
-
|
7
|
+
DIR_CIFAR10 = "cifar-10-batches-bin"
|
8
8
|
|
9
9
|
module DNN
|
10
10
|
module CIFAR10
|
11
11
|
class DNN_CIFAR10_LoadError < DNN_Error; end
|
12
12
|
|
13
|
-
private_class_method :load_binary
|
14
|
-
|
15
13
|
def self.downloads
|
16
|
-
return if Dir.exist?(__dir__ + "/" +
|
14
|
+
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
|
17
15
|
Downloader.download(URL_CIFAR10)
|
18
|
-
cifar10_binary_file_name = __dir__ + "/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
16
|
+
cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
19
17
|
begin
|
20
18
|
Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
|
21
|
-
Archive::Tar::Minitar::unpack(gz, __dir__)
|
19
|
+
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
22
20
|
end
|
23
21
|
ensure
|
24
22
|
File.unlink(cifar10_binary_file_name)
|
@@ -29,11 +27,11 @@ module DNN
|
|
29
27
|
downloads
|
30
28
|
bin = ""
|
31
29
|
(1..5).each do |i|
|
32
|
-
fname = __dir__ + "/#{
|
30
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
|
33
31
|
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
34
32
|
bin << File.binread(fname)
|
35
33
|
end
|
36
|
-
x_bin, y_bin = load_binary(bin, 50000)
|
34
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
|
37
35
|
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
38
36
|
y_train = Numo::UInt8.from_binary(y_bin)
|
39
37
|
[x_train, y_train]
|
@@ -41,10 +39,10 @@ module DNN
|
|
41
39
|
|
42
40
|
def self.load_test
|
43
41
|
downloads
|
44
|
-
fname = __dir__ + "/#{
|
42
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
|
45
43
|
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
46
44
|
bin = File.binread(fname)
|
47
|
-
x_bin, y_bin = load_binary(bin, 10000)
|
45
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
|
48
46
|
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
49
47
|
y_test = Numo::UInt8.from_binary(y_bin)
|
50
48
|
[x_test, y_test]
|
data/lib/dnn/cifar100.rb
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
require "zlib"
|
2
|
+
require "archive/tar/minitar"
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
+
require_relative "downloader"
|
5
|
+
|
6
|
+
URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
|
7
|
+
DIR_CIFAR100 = "cifar-100-binary"
|
8
|
+
|
9
|
+
module DNN
  # Downloads the CIFAR-100 binary archive on demand and loads it into
  # Numo arrays.
  module CIFAR100
    class DNN_CIFAR100_LoadError < DNN_Error; end

    # Fetches URL_CIFAR100 and unpacks it into "<this dir>/downloads".
    # Does nothing when the extracted DIR_CIFAR100 directory already exists.
    # The downloaded archive file is removed whether or not unpacking succeeds.
    def self.downloads
      download_dir = "#{__dir__}/downloads"
      return if Dir.exist?("#{download_dir}/#{DIR_CIFAR100}")

      Downloader.download(URL_CIFAR100)
      archive_path = "#{download_dir}/#{URL_CIFAR100.match(%r`.+/(.+)`)[1]}"
      begin
        Zlib::GzipReader.open(archive_path) do |gz|
          Archive::Tar::Minitar.unpack(gz, download_dir)
        end
      ensure
        File.unlink(archive_path)
      end
    end

    # Loads train.bin (50000 records).
    # @return [Array] [x_train, y_train]: x_train is a Numo::UInt8 array
    #   reshaped to (50000, 3, 32, 32) and transposed to (50000, 32, 32, 3);
    #   y_train is a Numo::UInt8 array of shape (50000, 2) holding the two
    #   label bytes of each record.
    # @raise [DNN_CIFAR100_LoadError] if train.bin is missing after the download step.
    def self.load_train
      downloads
      path = "#{__dir__}/downloads/#{DIR_CIFAR100}/train.bin"
      raise DNN_CIFAR100_LoadError, "file \"#{path}\" is not found." unless File.exist?(path)

      images, labels = CIFAR100.load_binary(File.binread(path), 50000)
      x_train = Numo::UInt8.from_binary(images).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
      y_train = Numo::UInt8.from_binary(labels).reshape(50000, 2)
      [x_train, y_train]
    end

    # Loads test.bin (10000 records).
    # @return [Array] [x_test, y_test]: x_test is a Numo::UInt8 array
    #   reshaped to (10000, 3, 32, 32) and transposed to (10000, 32, 32, 3);
    #   y_test is a Numo::UInt8 array of shape (10000, 2) holding the two
    #   label bytes of each record.
    # @raise [DNN_CIFAR100_LoadError] if test.bin is missing after the download step.
    def self.load_test
      downloads
      path = "#{__dir__}/downloads/#{DIR_CIFAR100}/test.bin"
      raise DNN_CIFAR100_LoadError, "file \"#{path}\" is not found." unless File.exist?(path)

      images, labels = CIFAR100.load_binary(File.binread(path), 10000)
      x_test = Numo::UInt8.from_binary(images).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
      y_test = Numo::UInt8.from_binary(labels).reshape(10000, 2)
      [x_test, y_test]
    end
  end
end
|
data/lib/dnn/core/activations.rb
CHANGED
@@ -3,22 +3,22 @@ module DNN
|
|
3
3
|
|
4
4
|
class Sigmoid < Layers::Layer
|
5
5
|
def forward(x)
|
6
|
-
@
|
6
|
+
@y = 1 / (1 + NMath.exp(-x))
|
7
7
|
end
|
8
8
|
|
9
|
-
def backward(
|
10
|
-
|
9
|
+
def backward(dy)
|
10
|
+
dy * (1 - @y) * @y
|
11
11
|
end
|
12
12
|
end
|
13
13
|
|
14
14
|
|
15
15
|
class Tanh < Layers::Layer
|
16
16
|
def forward(x)
|
17
|
-
@
|
17
|
+
@y = NMath.tanh(x)
|
18
18
|
end
|
19
19
|
|
20
|
-
def backward(
|
21
|
-
|
20
|
+
def backward(dy)
|
21
|
+
dy * (1 - @y**2)
|
22
22
|
end
|
23
23
|
end
|
24
24
|
|
@@ -29,8 +29,8 @@ module DNN
|
|
29
29
|
x / (1 + x.abs)
|
30
30
|
end
|
31
31
|
|
32
|
-
def backward(
|
33
|
-
|
32
|
+
def backward(dy)
|
33
|
+
dy * (1 / (1 + @x.abs)**2)
|
34
34
|
end
|
35
35
|
end
|
36
36
|
|
@@ -41,8 +41,8 @@ module DNN
|
|
41
41
|
NMath.log(1 + NMath.exp(x))
|
42
42
|
end
|
43
43
|
|
44
|
-
def backward(
|
45
|
-
|
44
|
+
def backward(dy)
|
45
|
+
dy * (1 / (1 + NMath.exp(-@x)))
|
46
46
|
end
|
47
47
|
end
|
48
48
|
|
@@ -50,11 +50,11 @@ module DNN
|
|
50
50
|
class Swish < Layers::Layer
|
51
51
|
def forward(x)
|
52
52
|
@x = x
|
53
|
-
@
|
53
|
+
@y = x * (1 / (1 + NMath.exp(-x)))
|
54
54
|
end
|
55
55
|
|
56
|
-
def backward(
|
57
|
-
|
56
|
+
def backward(dy)
|
57
|
+
dy * (@y + (1 / (1 + NMath.exp(-@x))) * (1 - @y))
|
58
58
|
end
|
59
59
|
end
|
60
60
|
|
@@ -66,23 +66,25 @@ module DNN
|
|
66
66
|
x
|
67
67
|
end
|
68
68
|
|
69
|
-
def backward(
|
69
|
+
def backward(dy)
|
70
70
|
@x[@x > 0] = 1
|
71
71
|
@x[@x <= 0] = 0
|
72
|
-
|
72
|
+
dy * @x
|
73
73
|
end
|
74
74
|
end
|
75
75
|
|
76
76
|
|
77
77
|
class LeakyReLU < Layers::Layer
|
78
|
+
# @return [Float] Return the alpha value.
|
78
79
|
attr_reader :alpha
|
79
80
|
|
80
|
-
def
|
81
|
-
|
81
|
+
def self.from_hash(hash)
|
82
|
+
self.new(hash[:alpha])
|
82
83
|
end
|
83
84
|
|
84
|
-
|
85
|
-
|
85
|
+
# @param [Float] alpha The slope when the output value is negative.
|
86
|
+
def initialize(alpha = 0.3)
|
87
|
+
@alpha = alpha
|
86
88
|
end
|
87
89
|
|
88
90
|
def forward(x)
|
@@ -92,10 +94,10 @@ module DNN
|
|
92
94
|
x * a
|
93
95
|
end
|
94
96
|
|
95
|
-
def backward(
|
97
|
+
def backward(dy)
|
96
98
|
@x[@x > 0] = 1
|
97
99
|
@x[@x <= 0] = @alpha
|
98
|
-
|
100
|
+
dy * @x
|
99
101
|
end
|
100
102
|
|
101
103
|
def to_hash
|
@@ -105,12 +107,14 @@ module DNN
|
|
105
107
|
|
106
108
|
|
107
109
|
class ELU < Layers::Layer
|
110
|
+
# @return [Float] Return the alpha value.
|
108
111
|
attr_reader :alpha
|
109
112
|
|
110
|
-
def self.
|
113
|
+
def self.from_hash(hash)
|
111
114
|
self.new(hash[:alpha])
|
112
115
|
end
|
113
116
|
|
117
|
+
# @param [Float] alpha The slope when the output value is negative.
|
114
118
|
def initialize(alpha = 1.0)
|
115
119
|
@alpha = alpha
|
116
120
|
end
|
@@ -126,13 +130,13 @@ module DNN
|
|
126
130
|
x1 + x2
|
127
131
|
end
|
128
132
|
|
129
|
-
def backward(
|
133
|
+
def backward(dy)
|
130
134
|
dx = Xumo::SFloat.ones(@x.shape)
|
131
135
|
dx[@x < 0] = 0
|
132
136
|
dx2 = Xumo::SFloat.zeros(@x.shape)
|
133
137
|
dx2[@x < 0] = 1
|
134
138
|
dx2 *= @alpha * NMath.exp(@x)
|
135
|
-
|
139
|
+
dy * (dx + dx2)
|
136
140
|
end
|
137
141
|
|
138
142
|
def to_hash
|