ruby-dnn 0.16.2 → 1.0.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +22 -0
- data/examples/api-examples/early_stopping_example.rb +1 -1
- data/examples/api-examples/initializer_example.rb +1 -1
- data/examples/api-examples/regularizer_example.rb +1 -1
- data/examples/dcgan/dcgan.rb +10 -3
- data/examples/pix2pix/dcgan.rb +4 -0
- data/examples/pix2pix/train.rb +5 -2
- data/examples/vae.rb +0 -6
- data/lib/dnn/core/callbacks.rb +7 -3
- data/lib/dnn/core/error.rb +2 -2
- data/lib/dnn/core/initializers.rb +5 -5
- data/lib/dnn/core/iterator.rb +4 -1
- data/lib/dnn/core/layers/basic_layers.rb +42 -65
- data/lib/dnn/core/layers/cnn_layers.rb +34 -35
- data/lib/dnn/core/layers/embedding.rb +3 -24
- data/lib/dnn/core/layers/math_layers.rb +12 -0
- data/lib/dnn/core/layers/merge_layers.rb +13 -13
- data/lib/dnn/core/layers/normalizations.rb +4 -4
- data/lib/dnn/core/layers/rnn_layers.rb +46 -46
- data/lib/dnn/core/link.rb +8 -8
- data/lib/dnn/core/losses.rb +10 -20
- data/lib/dnn/core/models.rb +23 -46
- data/lib/dnn/core/monkey_patch.rb +10 -0
- data/lib/dnn/core/optimizers.rb +1 -2
- data/lib/dnn/core/param.rb +2 -2
- data/lib/dnn/core/regularizers.rb +1 -1
- data/lib/dnn/core/savers.rb +2 -2
- data/lib/dnn/core/tensor.rb +1 -1
- data/lib/dnn/datasets/cifar10.rb +1 -1
- data/lib/dnn/datasets/cifar100.rb +1 -1
- data/lib/dnn/datasets/downloader.rb +1 -1
- data/lib/dnn/datasets/fashion-mnist.rb +1 -1
- data/lib/dnn/datasets/iris.rb +1 -1
- data/lib/dnn/datasets/mnist.rb +1 -1
- data/lib/dnn/datasets/stl-10.rb +2 -2
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0db9ac3047ba8c15d903ace901f5e4e332835d11dffca2f441664ae843049d1d
|
4
|
+
data.tar.gz: f1b4bf61da8a48b8ad483eb806ab443bb40f1b0d88573c2d901ae45299abf86d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 880fe0688bb5b15c016fdddb15b18f5e0b3ba2a45ae36292182adf8def20d93ca3ae176747dbf3d1369ea28cdaaf23e7cd9e96d0a0c6c4bb92db27131d8f4d93
|
7
|
+
data.tar.gz: 4d00dc6831f0c82e0dc1b4128d98dc391456410cbf26e0d0192b19092333664774eccabd2db0a609b739c1c40664e148dcd69152cdc45394690e21d785e1acc0
|
data/README.md
CHANGED
@@ -42,6 +42,11 @@ model << Dense.new(10)
|
|
42
42
|
model.setup(Adam.new, SoftmaxCrossEntropy.new)
|
43
43
|
|
44
44
|
model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
|
45
|
+
|
46
|
+
|
47
|
+
accuracy, loss = model.evaluate(x_test, y_test)
|
48
|
+
puts "accuracy: #{accuracy}"
|
49
|
+
puts "loss: #{loss}"
|
45
50
|
```
|
46
51
|
|
47
52
|
When create a model with 'define by run' style:
|
@@ -71,6 +76,10 @@ model = MLP.new
|
|
71
76
|
model.setup(Adam.new, SoftmaxCrossEntropy.new)
|
72
77
|
|
73
78
|
model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
|
79
|
+
|
80
|
+
accuracy, loss = model.evaluate(x_test, y_test)
|
81
|
+
puts "accuracy: #{accuracy}"
|
82
|
+
puts "loss: #{loss}"
|
74
83
|
```
|
75
84
|
|
76
85
|
Please refer to examples for basic usage.
|
@@ -86,6 +95,19 @@ If you want to know more detailed information, please refer to the source code.
|
|
86
95
|
| Optimizers | SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, RMSPropGraves, Adam, AdaBound |
|
87
96
|
| Losses | MeanSquaredError, MeanAbsoluteError, Hinge, HuberLoss, SoftmaxCrossEntropy, SigmoidCrossEntropy |
|
88
97
|
|
98
|
+
## Datasets
|
99
|
+
● Iris
|
100
|
+
● MNIST
|
101
|
+
● Fashion-MNIST
|
102
|
+
● CIFAR-10
|
103
|
+
● CIFAR-100
|
104
|
+
● STL-10
|
105
|
+
|
106
|
+
## Examples
|
107
|
+
● VAE
|
108
|
+
● DCGAN
|
109
|
+
● Pix2pix
|
110
|
+
|
89
111
|
## TODO
|
90
112
|
● Write a test.
|
91
113
|
● Write a document.
|
data/examples/dcgan/dcgan.rb
CHANGED
@@ -61,6 +61,9 @@ class Discriminator < Model
|
|
61
61
|
@l4 = Conv2D.new(64, 4, padding: true)
|
62
62
|
@l5 = Dense.new(1024)
|
63
63
|
@l6 = Dense.new(1)
|
64
|
+
@bn1 = BatchNormalization.new
|
65
|
+
@bn2 = BatchNormalization.new
|
66
|
+
@bn3 = BatchNormalization.new
|
64
67
|
end
|
65
68
|
|
66
69
|
def forward(x)
|
@@ -69,12 +72,15 @@ class Discriminator < Model
|
|
69
72
|
x = LeakyReLU.(x, 0.2)
|
70
73
|
|
71
74
|
x = @l2.(x)
|
75
|
+
x = @bn1.(x)
|
72
76
|
x = LeakyReLU.(x, 0.2)
|
73
77
|
|
74
78
|
x = @l3.(x)
|
79
|
+
x = @bn2.(x)
|
75
80
|
x = LeakyReLU.(x, 0.2)
|
76
81
|
|
77
82
|
x = @l4.(x)
|
83
|
+
x = @bn3.(x)
|
78
84
|
x = LeakyReLU.(x, 0.2)
|
79
85
|
|
80
86
|
x = Flatten.(x)
|
@@ -119,10 +125,11 @@ class DCGAN < Model
|
|
119
125
|
batch_size = x_batch.shape[0]
|
120
126
|
noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
|
121
127
|
images = @gen.predict(noise)
|
122
|
-
|
123
|
-
|
128
|
+
y_real = Numo::SFloat.ones(batch_size, 1)
|
129
|
+
y_fake = Numo::SFloat.zeros(batch_size, 1)
|
124
130
|
@dis.enable_training
|
125
|
-
dis_loss = @dis.train_on_batch(
|
131
|
+
dis_loss = @dis.train_on_batch(x_batch, y_real)
|
132
|
+
dis_loss + @dis.train_on_batch(images, y_fake)
|
126
133
|
|
127
134
|
noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
|
128
135
|
label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
|
data/examples/pix2pix/dcgan.rb
CHANGED
@@ -2,6 +2,8 @@ include DNN::Models
|
|
2
2
|
include DNN::Layers
|
3
3
|
|
4
4
|
class Generator < Model
|
5
|
+
attr_reader :generate_images
|
6
|
+
|
5
7
|
def initialize(input_shape)
|
6
8
|
super()
|
7
9
|
@input_shape = input_shape
|
@@ -25,6 +27,7 @@ class Generator < Model
|
|
25
27
|
@bn7 = BatchNormalization.new
|
26
28
|
@bn8 = BatchNormalization.new
|
27
29
|
@bn9 = BatchNormalization.new
|
30
|
+
@generate_images = nil
|
28
31
|
end
|
29
32
|
|
30
33
|
def forward(x)
|
@@ -72,6 +75,7 @@ class Generator < Model
|
|
72
75
|
|
73
76
|
x = @l11.(x)
|
74
77
|
x = Tanh.(x)
|
78
|
+
@generate_images = x.data
|
75
79
|
x
|
76
80
|
end
|
77
81
|
end
|
data/examples/pix2pix/train.rb
CHANGED
@@ -24,6 +24,7 @@ gen = Generator.new([32, 32, 1])
|
|
24
24
|
dis = Discriminator.new([32, 32, 1], [32, 32, 3])
|
25
25
|
dcgan = DCGAN.new(gen, dis)
|
26
26
|
|
27
|
+
gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new)
|
27
28
|
dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
|
28
29
|
dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
|
29
30
|
|
@@ -35,7 +36,9 @@ num_batchs = x_in.shape[0] / batch_size
|
|
35
36
|
(1..epochs).each do |epoch|
|
36
37
|
num_batchs.times do |index|
|
37
38
|
x_in, x_out = iter1.next_batch(batch_size)
|
38
|
-
|
39
|
+
gen_loss = gen.train_on_batch(x_in, x_out)
|
40
|
+
|
41
|
+
images = gen.generate_images
|
39
42
|
y_real = Numo::SFloat.ones(batch_size, 1)
|
40
43
|
y_fake = Numo::SFloat.zeros(batch_size, 1)
|
41
44
|
dis.enable_training
|
@@ -45,7 +48,7 @@ num_batchs = x_in.shape[0] / batch_size
|
|
45
48
|
x_in, x_out = iter2.next_batch(batch_size)
|
46
49
|
dcgan_loss = dcgan.train_on_batch(x_in, y_real)
|
47
50
|
|
48
|
-
puts "epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
|
51
|
+
puts "epoch: #{epoch}, index: #{index}, gen_loss: #{gen_loss}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
|
49
52
|
end
|
50
53
|
iter1.reset
|
51
54
|
iter2.reset
|
data/examples/vae.rb
CHANGED
@@ -9,16 +9,10 @@ include DNN::Optimizers
|
|
9
9
|
include DNN::Losses
|
10
10
|
|
11
11
|
x_train, y_train = DNN::MNIST.load_train
|
12
|
-
x_test, y_test = DNN::MNIST.load_test
|
13
12
|
|
14
13
|
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
|
15
|
-
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
|
16
14
|
|
17
15
|
x_train /= 255
|
18
|
-
x_test /= 255
|
19
|
-
|
20
|
-
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
21
|
-
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
22
16
|
|
23
17
|
$z_dim = 2
|
24
18
|
$z_mean = nil
|
data/lib/dnn/core/callbacks.rb
CHANGED
@@ -27,10 +27,11 @@ module DNN
|
|
27
27
|
|
28
28
|
# This callback wrap the lambda function.
|
29
29
|
class LambdaCallback < Callback
|
30
|
-
|
31
|
-
|
30
|
+
# @param [Symbol] event Event to execute callback.
|
31
|
+
# @yield Register the contents of the callback.
|
32
|
+
def initialize(event, &block)
|
32
33
|
instance_eval do
|
33
|
-
define_singleton_method(event) {
|
34
|
+
define_singleton_method(event) { block.call }
|
34
35
|
end
|
35
36
|
end
|
36
37
|
end
|
@@ -55,6 +56,9 @@ module DNN
|
|
55
56
|
end
|
56
57
|
|
57
58
|
# A callback to stop training the model early after test on batch.
|
59
|
+
# @param [Symbol] trigger A log that triggers early stopping.
|
60
|
+
# Specify one of train_loss, test_loss, test_accuracy.
|
61
|
+
# @param [Float] tolerance Tolerance value for early stopping.
|
58
62
|
class EarlyStopping < Callback
|
59
63
|
def initialize(trigger, tolerance)
|
60
64
|
@trigger = trigger
|
data/lib/dnn/core/error.rb
CHANGED
@@ -6,7 +6,7 @@ module DNN
|
|
6
6
|
return nil unless hash
|
7
7
|
initializer_class = DNN.const_get(hash[:class])
|
8
8
|
initializer = initializer_class.allocate
|
9
|
-
raise
|
9
|
+
raise DNNError, "#{initializer.class} is not an instance of #{self} class." unless initializer.is_a?(self)
|
10
10
|
initializer.load_hash(hash)
|
11
11
|
initializer
|
12
12
|
end
|
@@ -122,8 +122,8 @@ module DNN
|
|
122
122
|
|
123
123
|
def init_param(layer, param)
|
124
124
|
Xumo::SFloat.srand(@seed)
|
125
|
-
|
126
|
-
param.data = param.data.rand_norm / Math.sqrt(
|
125
|
+
num_prev_units = layer.input_shape.reduce(:*)
|
126
|
+
param.data = param.data.rand_norm / Math.sqrt(num_prev_units)
|
127
127
|
end
|
128
128
|
end
|
129
129
|
|
@@ -134,8 +134,8 @@ module DNN
|
|
134
134
|
|
135
135
|
def init_param(layer, param)
|
136
136
|
Xumo::SFloat.srand(@seed)
|
137
|
-
|
138
|
-
param.data = param.data.rand_norm / Math.sqrt(
|
137
|
+
num_prev_units = layer.input_shape.reduce(:*)
|
138
|
+
param.data = param.data.rand_norm / Math.sqrt(num_prev_units) * Math.sqrt(2)
|
139
139
|
end
|
140
140
|
end
|
141
141
|
|
data/lib/dnn/core/iterator.rb
CHANGED
@@ -21,7 +21,7 @@ module DNN
|
|
21
21
|
# @param [Integer] batch_size Required batch size.
|
22
22
|
# @return [Array] Returns the mini batch in the form [x_batch, y_batch].
|
23
23
|
def next_batch(batch_size)
|
24
|
-
raise
|
24
|
+
raise DNNError, "This iterator has not next batch. Please call reset." unless has_next?
|
25
25
|
if @indexes.length <= batch_size
|
26
26
|
batch_indexes = @indexes
|
27
27
|
@has_next = false
|
@@ -60,6 +60,9 @@ module DNN
|
|
60
60
|
@has_next
|
61
61
|
end
|
62
62
|
|
63
|
+
# Run a loop with all data separated by batch
|
64
|
+
# @param [Integer] batch_size Batch size.
|
65
|
+
# @yield Executes block by receiving the specified arguments (x_batch, y_batch).
|
63
66
|
def foreach(batch_size, &block)
|
64
67
|
steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
|
65
68
|
steps.times do |step|
|
@@ -2,18 +2,14 @@ module DNN
|
|
2
2
|
module Layers
|
3
3
|
|
4
4
|
module LayerNode
|
5
|
-
def forward(
|
6
|
-
x =
|
7
|
-
|
5
|
+
def forward(input)
|
6
|
+
x = input.data
|
7
|
+
prev = (input.is_a?(Tensor) ? input.link : input)
|
8
8
|
y = forward_node(x)
|
9
|
-
link = Link.new(
|
9
|
+
link = Link.new(prev, self)
|
10
10
|
Tensor.new(y, link)
|
11
11
|
end
|
12
12
|
|
13
|
-
def backward(dy)
|
14
|
-
backward_node(dy)
|
15
|
-
end
|
16
|
-
|
17
13
|
def forward_node(x)
|
18
14
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward_node'"
|
19
15
|
end
|
@@ -26,6 +22,7 @@ module DNN
|
|
26
22
|
# Super class of all layer classes.
|
27
23
|
class Layer
|
28
24
|
attr_reader :input_shape
|
25
|
+
attr_reader :output_shape
|
29
26
|
|
30
27
|
def self.call(x, *args)
|
31
28
|
new(*args).(x)
|
@@ -35,7 +32,7 @@ module DNN
|
|
35
32
|
return nil unless hash
|
36
33
|
layer_class = DNN.const_get(hash[:class])
|
37
34
|
layer = layer_class.allocate
|
38
|
-
raise
|
35
|
+
raise DNNError, "#{layer.class} is not an instance of #{self} class." unless layer.is_a?(self)
|
39
36
|
layer.load_hash(hash)
|
40
37
|
layer
|
41
38
|
end
|
@@ -45,18 +42,19 @@ module DNN
|
|
45
42
|
end
|
46
43
|
|
47
44
|
# Forward propagation and create a link.
|
48
|
-
# @param [Tensor]
|
45
|
+
# @param [Tensor | Param] input Input tensor or param.
|
49
46
|
# @return [Tensor] Output tensor.
|
50
|
-
def call(
|
51
|
-
|
52
|
-
build(
|
53
|
-
forward(
|
47
|
+
def call(input)
|
48
|
+
input = Tensor.new(input) if !input.is_a?(Tensor) && !input.is_a?(Param)
|
49
|
+
build(input.data.shape[1..-1]) unless built?
|
50
|
+
forward(input)
|
54
51
|
end
|
55
52
|
|
56
53
|
# Build the layer.
|
57
54
|
# @param [Array] input_shape Setting the shape of the input data.
|
58
55
|
def build(input_shape)
|
59
56
|
@input_shape = input_shape
|
57
|
+
@output_shape = compute_output_shape
|
60
58
|
@built = true
|
61
59
|
end
|
62
60
|
|
@@ -66,16 +64,16 @@ module DNN
|
|
66
64
|
end
|
67
65
|
|
68
66
|
# Forward propagation.
|
69
|
-
# @param [Tensor]
|
67
|
+
# @param [Tensor] input Input tensor or param.
|
70
68
|
# @return [Tensor] Output tensor.
|
71
|
-
def forward(
|
69
|
+
def forward(input)
|
72
70
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward'"
|
73
71
|
end
|
74
72
|
|
75
73
|
# Please reimplement this method as needed.
|
76
74
|
# The default implementation return input_shape.
|
77
75
|
# @return [Array] Return the shape of the output data.
|
78
|
-
def
|
76
|
+
def compute_output_shape
|
79
77
|
@input_shape
|
80
78
|
end
|
81
79
|
|
@@ -135,60 +133,37 @@ module DNN
|
|
135
133
|
end
|
136
134
|
|
137
135
|
class InputLayer < Layer
|
138
|
-
include LayerNode
|
139
|
-
|
140
|
-
def self.call(input)
|
141
|
-
shape = input.is_a?(Tensor) ? input.data.shape : input.shape
|
142
|
-
new(shape[1..-1]).(input)
|
143
|
-
end
|
144
|
-
|
145
136
|
# @param [Array] input_dim_or_shape Setting the shape or dimension of the input data.
|
146
137
|
def initialize(input_dim_or_shape)
|
147
138
|
super()
|
148
139
|
@input_shape = input_dim_or_shape.is_a?(Array) ? input_dim_or_shape : [input_dim_or_shape]
|
149
140
|
end
|
150
141
|
|
151
|
-
def call(input)
|
152
|
-
build(@input_shape) unless built?
|
153
|
-
if input.is_a?(Tensor)
|
154
|
-
x = input.data
|
155
|
-
prev_link = input&.link
|
156
|
-
else
|
157
|
-
x = input
|
158
|
-
prev_link = nil
|
159
|
-
end
|
160
|
-
Tensor.new(forward_node(x), Link.new(prev_link, self))
|
161
|
-
end
|
162
|
-
|
163
142
|
def build(input_shape)
|
164
|
-
@
|
143
|
+
super(@input_shape)
|
165
144
|
end
|
166
145
|
|
167
|
-
def
|
146
|
+
def forward(x)
|
168
147
|
unless x.shape[1..-1] == @input_shape
|
169
|
-
raise
|
148
|
+
raise DNNShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
|
170
149
|
end
|
171
150
|
x
|
172
151
|
end
|
173
152
|
|
174
|
-
def backward_node(dy)
|
175
|
-
dy
|
176
|
-
end
|
177
|
-
|
178
153
|
def to_proc
|
179
154
|
method(:call).to_proc
|
180
155
|
end
|
181
156
|
|
182
157
|
def >>(layer)
|
183
158
|
if RUBY_VERSION < "2.6.0"
|
184
|
-
raise
|
159
|
+
raise DNNError, "Function composition is not supported before ruby version 2.6.0."
|
185
160
|
end
|
186
161
|
to_proc >> layer
|
187
162
|
end
|
188
163
|
|
189
164
|
def <<(layer)
|
190
165
|
if RUBY_VERSION < "2.6.0"
|
191
|
-
raise
|
166
|
+
raise DNNError, "Function composition is not supported before ruby version 2.6.0."
|
192
167
|
end
|
193
168
|
to_proc << layer
|
194
169
|
end
|
@@ -267,10 +242,10 @@ module DNN
|
|
267
242
|
class Dense < Connection
|
268
243
|
include LayerNode
|
269
244
|
|
270
|
-
attr_reader :
|
245
|
+
attr_reader :num_units
|
271
246
|
|
272
|
-
# @param [Integer]
|
273
|
-
def initialize(
|
247
|
+
# @param [Integer] num_units Number of nodes.
|
248
|
+
def initialize(num_units,
|
274
249
|
weight_initializer: Initializers::RandomNormal.new,
|
275
250
|
bias_initializer: Initializers::Zeros.new,
|
276
251
|
weight_regularizer: nil,
|
@@ -278,17 +253,17 @@ module DNN
|
|
278
253
|
use_bias: true)
|
279
254
|
super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
|
280
255
|
weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
|
281
|
-
@
|
256
|
+
@num_units = num_units
|
282
257
|
end
|
283
258
|
|
284
259
|
def build(input_shape)
|
285
260
|
unless input_shape.length == 1
|
286
|
-
raise
|
261
|
+
raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 1 dimensional."
|
287
262
|
end
|
288
263
|
super
|
289
|
-
|
290
|
-
@weight.data = Xumo::SFloat.new(
|
291
|
-
@bias.data = Xumo::SFloat.new(@
|
264
|
+
num_prev_units = input_shape[0]
|
265
|
+
@weight.data = Xumo::SFloat.new(num_prev_units, @num_units)
|
266
|
+
@bias.data = Xumo::SFloat.new(@num_units) if @bias
|
292
267
|
init_weight_and_bias
|
293
268
|
end
|
294
269
|
|
@@ -307,16 +282,16 @@ module DNN
|
|
307
282
|
dy.dot(@weight.data.transpose)
|
308
283
|
end
|
309
284
|
|
310
|
-
def
|
311
|
-
[@
|
285
|
+
def compute_output_shape
|
286
|
+
[@num_units]
|
312
287
|
end
|
313
288
|
|
314
289
|
def to_hash
|
315
|
-
super(
|
290
|
+
super(num_units: @num_units)
|
316
291
|
end
|
317
292
|
|
318
293
|
def load_hash(hash)
|
319
|
-
initialize(hash[:
|
294
|
+
initialize(hash[:num_units],
|
320
295
|
weight_initializer: Initializers::Initializer.from_hash(hash[:weight_initializer]),
|
321
296
|
bias_initializer: Initializers::Initializer.from_hash(hash[:bias_initializer]),
|
322
297
|
weight_regularizer: Regularizers::Regularizer.from_hash(hash[:weight_regularizer]),
|
@@ -329,14 +304,14 @@ module DNN
|
|
329
304
|
include LayerNode
|
330
305
|
|
331
306
|
def forward_node(x)
|
332
|
-
x.reshape(x.shape[0],
|
307
|
+
x.reshape(x.shape[0], *@output_shape)
|
333
308
|
end
|
334
309
|
|
335
310
|
def backward_node(dy)
|
336
311
|
dy.reshape(dy.shape[0], *@input_shape)
|
337
312
|
end
|
338
313
|
|
339
|
-
def
|
314
|
+
def compute_output_shape
|
340
315
|
[@input_shape.reduce(:*)]
|
341
316
|
end
|
342
317
|
end
|
@@ -344,11 +319,13 @@ module DNN
|
|
344
319
|
class Reshape < Layer
|
345
320
|
include LayerNode
|
346
321
|
|
347
|
-
|
348
|
-
|
349
|
-
def initialize(output_shape)
|
322
|
+
def initialize(shape)
|
350
323
|
super()
|
351
|
-
@
|
324
|
+
@shape = shape
|
325
|
+
end
|
326
|
+
|
327
|
+
def compute_output_shape
|
328
|
+
@shape
|
352
329
|
end
|
353
330
|
|
354
331
|
def forward_node(x)
|
@@ -360,11 +337,11 @@ module DNN
|
|
360
337
|
end
|
361
338
|
|
362
339
|
def to_hash
|
363
|
-
super(
|
340
|
+
super(shape: @shape)
|
364
341
|
end
|
365
342
|
|
366
343
|
def load_hash(hash)
|
367
|
-
initialize(hash[:
|
344
|
+
initialize(hash[:shape])
|
368
345
|
end
|
369
346
|
end
|
370
347
|
|