ruby-dnn 0.9.4 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +39 -3
- data/Rakefile +6 -0
- data/examples/cifar100_example.rb +71 -0
- data/examples/cifar10_example.rb +2 -1
- data/examples/iris_example.rb +2 -1
- data/examples/mnist_conv2d_example.rb +2 -1
- data/examples/mnist_example.rb +2 -3
- data/examples/mnist_lstm_example.rb +2 -1
- data/ext/cifar_loader/cifar_loader.c +77 -0
- data/ext/cifar_loader/extconf.rb +3 -0
- data/lib/dnn.rb +1 -0
- data/lib/dnn/{lib/cifar10.rb → cifar10.rb} +9 -11
- data/lib/dnn/cifar100.rb +49 -0
- data/lib/dnn/core/activations.rb +28 -24
- data/lib/dnn/core/cnn_layers.rb +216 -94
- data/lib/dnn/core/dataset.rb +21 -5
- data/lib/dnn/core/initializers.rb +3 -3
- data/lib/dnn/core/layers.rb +81 -150
- data/lib/dnn/core/losses.rb +88 -49
- data/lib/dnn/core/model.rb +97 -74
- data/lib/dnn/core/normalizations.rb +72 -0
- data/lib/dnn/core/optimizers.rb +171 -78
- data/lib/dnn/core/regularizers.rb +92 -22
- data/lib/dnn/core/rnn_layers.rb +146 -121
- data/lib/dnn/core/utils.rb +4 -3
- data/lib/dnn/{lib/downloader.rb → downloader.rb} +5 -1
- data/lib/dnn/{lib/image.rb → image.rb} +1 -1
- data/lib/dnn/{lib/iris.rb → iris.rb} +1 -1
- data/lib/dnn/{lib/mnist.rb → mnist.rb} +4 -3
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +1 -1
- metadata +13 -12
- data/API-Reference.ja.md +0 -978
- data/LIB-API-Reference.ja.md +0 -97
- data/ext/cifar10_loader/cifar10_loader.c +0 -44
- data/ext/cifar10_loader/extconf.rb +0 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9fbaca2827a3877b9c215d3a93b7c0e20a4a456471541e3e3f0140a95abf695b
|
4
|
+
data.tar.gz: 0e57474b603a0ba4d08d66176fea34b77e9dbc0c51639635bb861779a5bd047e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e02bcedc98d20dda9c4aae46955af777586cb433ca1d4f479248e56b2ec08a38fbb1b1023c644620664c40fda420a4584d98c1dd0080ab29212c0a0176b669d7
|
7
|
+
data.tar.gz: bb1486c4a85edbb5b5aed8e8feaffd15937893883d5d31b831427e7bcf663e4157452dd565f0bf4909b444e1241213f8f0a682f56d03d267954e66b9d312e7ab
|
data/README.md
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
# ruby-dnn
|
2
|
+
[![Gem Version](https://badge.fury.io/rb/ruby-dnn.svg)](https://badge.fury.io/rb/ruby-dnn)
|
2
3
|
|
3
4
|
ruby-dnn is a deep learning library written in Ruby. It supports fully connected neural networks, convolutional neural networks, and recurrent neural networks.
|
4
5
|
Currently, it achieves 99% accuracy on MNIST and 74% on CIFAR-10.
|
@@ -21,10 +22,45 @@ Or install it yourself as:
|
|
21
22
|
|
22
23
|
## Usage
|
23
24
|
|
24
|
-
|
25
|
-
|
25
|
+
### MNIST MLP example
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
model = Model.new
|
29
|
+
|
30
|
+
model << InputLayer.new(784)
|
31
|
+
|
32
|
+
model << Dense.new(256)
|
33
|
+
model << ReLU.new
|
34
|
+
|
35
|
+
model << Dense.new(256)
|
36
|
+
model << ReLU.new
|
37
|
+
|
38
|
+
model << Dense.new(10)
|
39
|
+
|
40
|
+
model.compile(RMSProp.new, SoftmaxCrossEntropy.new)
|
41
|
+
|
42
|
+
model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
43
|
+
|
44
|
+
```
|
45
|
+
|
46
|
+
For more usage examples, please refer to the `examples` directory.
|
26
47
|
For more detailed information, please refer to the source code.
|
27
48
|
|
49
|
+
## Implemented
|
50
|
+
|| Implemented classes |
|
51
|
+
|:-----------|------------:|
|
52
|
+
| Connections | Dense, Conv2D, Conv2D_Transpose, SimpleRNN, LSTM, GRU |
|
53
|
+
| Layers | Flatten, Reshape, Dropout, BatchNormalization, MaxPool2D, AvgPool2D, UnPool2D |
|
54
|
+
| Activations | Sigmoid, Tanh, Softsign, Softplus, Swish, ReLU, LeakyReLU, ELU |
|
55
|
+
| Optimizers | SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, Adam, RMSPropGraves |
|
56
|
+
| Losses | MeanSquaredError, MeanAbsoluteError, HuberLoss, SoftmaxCrossEntropy, SigmoidCrossEntropy |
|
57
|
+
|
58
|
+
## TODO
|
59
|
+
- Add a CI badge.
|
60
|
+
- Write tests.
|
61
|
+
- Write documentation.
|
62
|
+
- Add GPU support.
|
63
|
+
|
28
64
|
## Development
|
29
65
|
|
30
66
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
@@ -41,4 +77,4 @@ The gem is available as open source under the terms of the [MIT License](https:/
|
|
41
77
|
|
42
78
|
## Code of Conduct
|
43
79
|
|
44
|
-
Everyone interacting in the
|
80
|
+
Everyone interacting in the ruby-dnn project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/[USERNAME]/dnn/blob/master/CODE_OF_CONDUCT.md).
|
data/examples/cifar100_example.rb
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
require "dnn"
require "dnn/cifar100"
# To use numo/linalg for faster linear algebra, uncomment the next line.
# require "numo/linalg/autoloader"

include DNN::Layers
include DNN::Activations
include DNN::Optimizers
include DNN::Losses
Model = DNN::Model
CIFAR100 = DNN::CIFAR100

# Example: train a small CNN on CIFAR-100 using the 100-class labels.

x_train, y_train = CIFAR100.load_train
x_test, y_test = CIFAR100.load_test

# Images are loaded as UInt8 arrays of shape [N, 32, 32, 3]; cast to SFloat and scale into [0, 1].
x_train, x_test = [x_train, x_test].map { |images| Numo::SFloat.cast(images) / 255 }

# Each label row holds two bytes; column 1 is the 100-class label, which is one-hot encoded here.
y_train, y_test = [y_train, y_test].map do |labels|
  DNN::Utils.to_categorical(labels[true, 1], 100, Numo::SFloat)
end

model = Model.new
model << InputLayer.new([32, 32, 3])

# Three stages, each made of two 5x5 Conv2D -> BatchNormalization -> ReLU blocks
# (16, 32 and 64 filters). The first two stages end with 2x2 max pooling.
[16, 32, 64].each_with_index do |filters, stage|
  2.times do
    model << Conv2D.new(filters, 5, padding: true)
    model << BatchNormalization.new
    model << ReLU.new
  end
  model << MaxPool2D.new(2) if stage < 2
end

# Classifier head.
model << Flatten.new
model << Dense.new(1024)
model << BatchNormalization.new
model << ReLU.new
model << Dropout.new(0.5)
model << Dense.new(100)

model.compile(Adam.new, SoftmaxCrossEntropy.new)

model.train(x_train, y_train, 10, batch_size: 100, test: [x_test, y_test])
|
data/examples/cifar10_example.rb
CHANGED
data/examples/iris_example.rb
CHANGED
data/examples/mnist_example.rb
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
require "dnn"
|
2
|
-
require "dnn/
|
2
|
+
require "dnn/mnist"
|
3
|
+
# If you use numo/linalg then please uncomment out.
|
3
4
|
# require "numo/linalg/autoloader"
|
4
5
|
|
5
6
|
include DNN::Layers
|
@@ -26,11 +27,9 @@ model = Model.new
|
|
26
27
|
model << InputLayer.new(784)
|
27
28
|
|
28
29
|
model << Dense.new(256)
|
29
|
-
model << BatchNormalization.new
|
30
30
|
model << ReLU.new
|
31
31
|
|
32
32
|
model << Dense.new(256)
|
33
|
-
model << BatchNormalization.new
|
34
33
|
model << ReLU.new
|
35
34
|
|
36
35
|
model << Dense.new(10)
|
data/ext/cifar_loader/cifar_loader.c
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
#include <ruby.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#define CIFAR_WIDTH 32
#define CIFAR_HEIGHT 32
#define CIFAR_CHANNEL 3
/* Number of image bytes in one CIFAR record. */
#define CIFAR_IMAGE_SIZE (CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL)

/*
 * Split a packed CIFAR binary blob into separate image and label byte strings.
 *
 * Each record in rb_bin is `num_labels` label bytes immediately followed by
 * CIFAR_IMAGE_SIZE image bytes. Labels and images are copied record by record
 * into two freshly allocated Ruby strings.
 *
 * rb_bin       - String holding at least rb_num_datas records.
 * rb_num_datas - Integer number of records to read.
 * num_labels   - Number of label bytes per record.
 *
 * Returns [x_bin, y_bin], where x_bin has rb_num_datas * CIFAR_IMAGE_SIZE bytes
 * and y_bin has rb_num_datas * num_labels bytes.
 * Raises TypeError if rb_bin is not a String (or String-convertible), and
 * ArgumentError if the record count is negative or rb_bin is too short.
 */
static VALUE cifar_load_binary(VALUE rb_bin, VALUE rb_num_datas, long num_labels) {
  long num_datas = NUM2LONG(rb_num_datas);
  long record_size = num_labels + CIFAR_IMAGE_SIZE;
  const uint8_t* bin;
  uint8_t* x_bin;
  uint8_t* y_bin;
  VALUE rb_x_bin;
  VALUE rb_y_bin;
  long i;

  StringValue(rb_bin);
  if (num_datas < 0) {
    rb_raise(rb_eArgError, "number of records must not be negative (%ld given)", num_datas);
  }
  /* The blob is read from a downloaded file: reject truncated input instead of reading past its end.
   * Dividing (rather than multiplying) also avoids overflow for huge counts. */
  if (num_datas > RSTRING_LEN(rb_bin) / record_size) {
    rb_raise(rb_eArgError, "binary too short: %ld records of %ld bytes requested, but only %ld bytes given",
             num_datas, record_size, (long)RSTRING_LEN(rb_bin));
  }

  /* Allocate the result strings directly, so there is no temporary C buffer to leak or to leave unchecked. */
  rb_x_bin = rb_str_new(NULL, num_datas * CIFAR_IMAGE_SIZE);
  rb_y_bin = rb_str_new(NULL, num_datas * num_labels);

  /* Take raw pointers only after all Ruby allocations are done. */
  bin = (const uint8_t*)RSTRING_PTR(rb_bin);
  x_bin = (uint8_t*)RSTRING_PTR(rb_x_bin);
  y_bin = (uint8_t*)RSTRING_PTR(rb_y_bin);

  for (i = 0; i < num_datas; i++) {
    const uint8_t* record = bin + i * record_size;
    memcpy(y_bin + i * num_labels, record, (size_t)num_labels);
    memcpy(x_bin + i * CIFAR_IMAGE_SIZE, record + num_labels, CIFAR_IMAGE_SIZE);
  }

  /* Keep rb_bin reachable until its buffer is no longer being read. */
  RB_GC_GUARD(rb_bin);
  return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}

/* DNN::CIFAR10.load_binary(bin, num_datas): records carry one label byte each. */
static VALUE cifar10_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
  (void)self;
  return cifar_load_binary(rb_bin, rb_num_datas, 1);
}

/* DNN::CIFAR100.load_binary(bin, num_datas): records carry two label bytes each. */
static VALUE cifar100_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
  (void)self;
  return cifar_load_binary(rb_bin, rb_num_datas, 2);
}

/* Extension entry point: defines load_binary on DNN::CIFAR10 and DNN::CIFAR100. */
void Init_cifar_loader(void) {
  VALUE rb_dnn = rb_define_module("DNN");
  VALUE rb_cifar10 = rb_define_module_under(rb_dnn, "CIFAR10");
  VALUE rb_cifar100 = rb_define_module_under(rb_dnn, "CIFAR100");

  rb_define_singleton_method(rb_cifar10, "load_binary", cifar10_load_binary, 2);
  rb_define_singleton_method(rb_cifar100, "load_binary", cifar100_load_binary, 2);
}
|
data/lib/dnn.rb
CHANGED
@@ -16,6 +16,7 @@ require_relative "dnn/core/param"
|
|
16
16
|
require_relative "dnn/core/dataset"
|
17
17
|
require_relative "dnn/core/initializers"
|
18
18
|
require_relative "dnn/core/layers"
|
19
|
+
require_relative "dnn/core/normalizations"
|
19
20
|
require_relative "dnn/core/activations"
|
20
21
|
require_relative "dnn/core/losses"
|
21
22
|
require_relative "dnn/core/regularizers"
|
data/lib/dnn/{lib/cifar10.rb → cifar10.rb}
CHANGED
@@ -1,24 +1,22 @@
|
|
1
1
|
require "zlib"
|
2
2
|
require "archive/tar/minitar"
|
3
|
-
require_relative "
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
4
|
require_relative "downloader"
|
5
5
|
|
6
6
|
URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
|
7
|
-
|
7
|
+
DIR_CIFAR10 = "cifar-10-batches-bin"
|
8
8
|
|
9
9
|
module DNN
|
10
10
|
module CIFAR10
|
11
11
|
class DNN_CIFAR10_LoadError < DNN_Error; end
|
12
12
|
|
13
|
-
private_class_method :load_binary
|
14
|
-
|
15
13
|
def self.downloads
|
16
|
-
return if Dir.exist?(__dir__ + "/" +
|
14
|
+
return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
|
17
15
|
Downloader.download(URL_CIFAR10)
|
18
|
-
cifar10_binary_file_name = __dir__ + "/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
16
|
+
cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
|
19
17
|
begin
|
20
18
|
Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
|
21
|
-
Archive::Tar::Minitar::unpack(gz, __dir__)
|
19
|
+
Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
|
22
20
|
end
|
23
21
|
ensure
|
24
22
|
File.unlink(cifar10_binary_file_name)
|
@@ -29,11 +27,11 @@ module DNN
|
|
29
27
|
downloads
|
30
28
|
bin = ""
|
31
29
|
(1..5).each do |i|
|
32
|
-
fname = __dir__ + "/#{
|
30
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
|
33
31
|
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
34
32
|
bin << File.binread(fname)
|
35
33
|
end
|
36
|
-
x_bin, y_bin = load_binary(bin, 50000)
|
34
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
|
37
35
|
x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
38
36
|
y_train = Numo::UInt8.from_binary(y_bin)
|
39
37
|
[x_train, y_train]
|
@@ -41,10 +39,10 @@ module DNN
|
|
41
39
|
|
42
40
|
def self.load_test
|
43
41
|
downloads
|
44
|
-
fname = __dir__ + "/#{
|
42
|
+
fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
|
45
43
|
raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
|
46
44
|
bin = File.binread(fname)
|
47
|
-
x_bin, y_bin = load_binary(bin, 10000)
|
45
|
+
x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
|
48
46
|
x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
|
49
47
|
y_test = Numo::UInt8.from_binary(y_bin)
|
50
48
|
[x_test, y_test]
|
data/lib/dnn/cifar100.rb
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
require "zlib"
|
2
|
+
require "archive/tar/minitar"
|
3
|
+
require_relative "../../ext/cifar_loader/cifar_loader"
|
4
|
+
require_relative "downloader"
|
5
|
+
|
6
|
+
URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
|
7
|
+
DIR_CIFAR100 = "cifar-100-binary"
|
8
|
+
|
9
|
+
module DNN
  # Loader for the CIFAR-100 dataset (binary version).
  #
  # On first use, the archive at URL_CIFAR100 is downloaded and unpacked into
  # "downloads/" next to this file. Later calls reuse the extracted files.
  # CIFAR100.load_binary itself is provided by the cifar_loader C extension.
  module CIFAR100
    # Raised when an expected data file is missing after download/extraction.
    class DNN_CIFAR100_LoadError < DNN_Error; end

    # Download and extract the dataset unless the extracted directory already exists.
    # The downloaded .tar.gz (presumably saved under "downloads/" by Downloader) is
    # deleted afterwards, even if extraction fails.
    def self.downloads
      return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR100)
      Downloader.download(URL_CIFAR100)
      cifar100_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR100.match(%r`.+/(.+)`)[1]
      begin
        Zlib::GzipReader.open(cifar100_binary_file_name) do |gz|
          Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
        end
      ensure
        File.unlink(cifar100_binary_file_name)
      end
    end

    # Load the training split, downloading the dataset first if needed.
    # @return [Array] [x_train, y_train]: x_train is a Numo::UInt8 array of shape
    #   [50000, 32, 32, 3]; y_train is a Numo::UInt8 array of shape [50000, 2]
    #   holding the two label bytes of each record.
    # @raise [DNN_CIFAR100_LoadError] If train.bin is missing.
    def self.load_train
      downloads
      load_batch("train.bin", 50000)
    end

    # Load the test split, downloading the dataset first if needed.
    # @return [Array] [x_test, y_test]: x_test is a Numo::UInt8 array of shape
    #   [10000, 32, 32, 3]; y_test is a Numo::UInt8 array of shape [10000, 2].
    # @raise [DNN_CIFAR100_LoadError] If test.bin is missing.
    def self.load_test
      downloads
      load_batch("test.bin", 10000)
    end

    # Read one extracted CIFAR-100 binary file and convert it into Numo arrays.
    # @param [String] file_name Name of the file inside the extracted dataset directory.
    # @param [Integer] num_datas Number of records to read from the file.
    # @return [Array] [x, y]: images as [num_datas, 32, 32, 3], labels as [num_datas, 2].
    # @raise [DNN_CIFAR100_LoadError] If the file does not exist.
    def self.load_batch(file_name, num_datas)
      fname = __dir__ + "/downloads/#{DIR_CIFAR100}/#{file_name}"
      raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
      x_bin, y_bin = CIFAR100.load_binary(File.binread(fname), num_datas)
      # Records are stored channel-major (3, 32, 32); transpose to channels-last.
      x = Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 32, 32).transpose(0, 2, 3, 1).clone
      y = Numo::UInt8.from_binary(y_bin).reshape(num_datas, 2)
      [x, y]
    end
    private_class_method :load_batch
  end
end
|
data/lib/dnn/core/activations.rb
CHANGED
@@ -3,22 +3,22 @@ module DNN
|
|
3
3
|
|
4
4
|
class Sigmoid < Layers::Layer
|
5
5
|
def forward(x)
|
6
|
-
@
|
6
|
+
@y = 1 / (1 + NMath.exp(-x))
|
7
7
|
end
|
8
8
|
|
9
|
-
def backward(
|
10
|
-
|
9
|
+
def backward(dy)
|
10
|
+
dy * (1 - @y) * @y
|
11
11
|
end
|
12
12
|
end
|
13
13
|
|
14
14
|
|
15
15
|
class Tanh < Layers::Layer
|
16
16
|
def forward(x)
|
17
|
-
@
|
17
|
+
@y = NMath.tanh(x)
|
18
18
|
end
|
19
19
|
|
20
|
-
def backward(
|
21
|
-
|
20
|
+
def backward(dy)
|
21
|
+
dy * (1 - @y**2)
|
22
22
|
end
|
23
23
|
end
|
24
24
|
|
@@ -29,8 +29,8 @@ module DNN
|
|
29
29
|
x / (1 + x.abs)
|
30
30
|
end
|
31
31
|
|
32
|
-
def backward(
|
33
|
-
|
32
|
+
def backward(dy)
|
33
|
+
dy * (1 / (1 + @x.abs)**2)
|
34
34
|
end
|
35
35
|
end
|
36
36
|
|
@@ -41,8 +41,8 @@ module DNN
|
|
41
41
|
NMath.log(1 + NMath.exp(x))
|
42
42
|
end
|
43
43
|
|
44
|
-
def backward(
|
45
|
-
|
44
|
+
def backward(dy)
|
45
|
+
dy * (1 / (1 + NMath.exp(-@x)))
|
46
46
|
end
|
47
47
|
end
|
48
48
|
|
@@ -50,11 +50,11 @@ module DNN
|
|
50
50
|
class Swish < Layers::Layer
|
51
51
|
def forward(x)
|
52
52
|
@x = x
|
53
|
-
@
|
53
|
+
@y = x * (1 / (1 + NMath.exp(-x)))
|
54
54
|
end
|
55
55
|
|
56
|
-
def backward(
|
57
|
-
|
56
|
+
def backward(dy)
|
57
|
+
dy * (@y + (1 / (1 + NMath.exp(-@x))) * (1 - @y))
|
58
58
|
end
|
59
59
|
end
|
60
60
|
|
@@ -66,23 +66,25 @@ module DNN
|
|
66
66
|
x
|
67
67
|
end
|
68
68
|
|
69
|
-
def backward(
|
69
|
+
def backward(dy)
|
70
70
|
@x[@x > 0] = 1
|
71
71
|
@x[@x <= 0] = 0
|
72
|
-
|
72
|
+
dy * @x
|
73
73
|
end
|
74
74
|
end
|
75
75
|
|
76
76
|
|
77
77
|
class LeakyReLU < Layers::Layer
|
78
|
+
# @return [Float] Return the alpha value.
|
78
79
|
attr_reader :alpha
|
79
80
|
|
80
|
-
def
|
81
|
-
|
81
|
+
def self.from_hash(hash)
|
82
|
+
self.new(hash[:alpha])
|
82
83
|
end
|
83
84
|
|
84
|
-
|
85
|
-
|
85
|
+
# @param [Float] alpha The slope when the output value is negative.
|
86
|
+
def initialize(alpha = 0.3)
|
87
|
+
@alpha = alpha
|
86
88
|
end
|
87
89
|
|
88
90
|
def forward(x)
|
@@ -92,10 +94,10 @@ module DNN
|
|
92
94
|
x * a
|
93
95
|
end
|
94
96
|
|
95
|
-
def backward(
|
97
|
+
def backward(dy)
|
96
98
|
@x[@x > 0] = 1
|
97
99
|
@x[@x <= 0] = @alpha
|
98
|
-
|
100
|
+
dy * @x
|
99
101
|
end
|
100
102
|
|
101
103
|
def to_hash
|
@@ -105,12 +107,14 @@ module DNN
|
|
105
107
|
|
106
108
|
|
107
109
|
class ELU < Layers::Layer
|
110
|
+
# @return [Float] Return the alpha value.
|
108
111
|
attr_reader :alpha
|
109
112
|
|
110
|
-
def self.
|
113
|
+
def self.from_hash(hash)
|
111
114
|
self.new(hash[:alpha])
|
112
115
|
end
|
113
116
|
|
117
|
+
# @param [Float] alpha The slope when the output value is negative.
|
114
118
|
def initialize(alpha = 1.0)
|
115
119
|
@alpha = alpha
|
116
120
|
end
|
@@ -126,13 +130,13 @@ module DNN
|
|
126
130
|
x1 + x2
|
127
131
|
end
|
128
132
|
|
129
|
-
def backward(
|
133
|
+
def backward(dy)
|
130
134
|
dx = Xumo::SFloat.ones(@x.shape)
|
131
135
|
dx[@x < 0] = 0
|
132
136
|
dx2 = Xumo::SFloat.zeros(@x.shape)
|
133
137
|
dx2[@x < 0] = 1
|
134
138
|
dx2 *= @alpha * NMath.exp(@x)
|
135
|
-
|
139
|
+
dy * (dx + dx2)
|
136
140
|
end
|
137
141
|
|
138
142
|
def to_hash
|