ruby-dnn 0.16.2 → 1.0.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (38) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +22 -0
  3. data/examples/api-examples/early_stopping_example.rb +1 -1
  4. data/examples/api-examples/initializer_example.rb +1 -1
  5. data/examples/api-examples/regularizer_example.rb +1 -1
  6. data/examples/dcgan/dcgan.rb +10 -3
  7. data/examples/pix2pix/dcgan.rb +4 -0
  8. data/examples/pix2pix/train.rb +5 -2
  9. data/examples/vae.rb +0 -6
  10. data/lib/dnn/core/callbacks.rb +7 -3
  11. data/lib/dnn/core/error.rb +2 -2
  12. data/lib/dnn/core/initializers.rb +5 -5
  13. data/lib/dnn/core/iterator.rb +4 -1
  14. data/lib/dnn/core/layers/basic_layers.rb +42 -65
  15. data/lib/dnn/core/layers/cnn_layers.rb +34 -35
  16. data/lib/dnn/core/layers/embedding.rb +3 -24
  17. data/lib/dnn/core/layers/math_layers.rb +12 -0
  18. data/lib/dnn/core/layers/merge_layers.rb +13 -13
  19. data/lib/dnn/core/layers/normalizations.rb +4 -4
  20. data/lib/dnn/core/layers/rnn_layers.rb +46 -46
  21. data/lib/dnn/core/link.rb +8 -8
  22. data/lib/dnn/core/losses.rb +10 -20
  23. data/lib/dnn/core/models.rb +23 -46
  24. data/lib/dnn/core/monkey_patch.rb +10 -0
  25. data/lib/dnn/core/optimizers.rb +1 -2
  26. data/lib/dnn/core/param.rb +2 -2
  27. data/lib/dnn/core/regularizers.rb +1 -1
  28. data/lib/dnn/core/savers.rb +2 -2
  29. data/lib/dnn/core/tensor.rb +1 -1
  30. data/lib/dnn/datasets/cifar10.rb +1 -1
  31. data/lib/dnn/datasets/cifar100.rb +1 -1
  32. data/lib/dnn/datasets/downloader.rb +1 -1
  33. data/lib/dnn/datasets/fashion-mnist.rb +1 -1
  34. data/lib/dnn/datasets/iris.rb +1 -1
  35. data/lib/dnn/datasets/mnist.rb +1 -1
  36. data/lib/dnn/datasets/stl-10.rb +2 -2
  37. data/lib/dnn/version.rb +1 -1
  38. metadata +2 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
4
- data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
3
+ metadata.gz: 0db9ac3047ba8c15d903ace901f5e4e332835d11dffca2f441664ae843049d1d
4
+ data.tar.gz: f1b4bf61da8a48b8ad483eb806ab443bb40f1b0d88573c2d901ae45299abf86d
5
5
  SHA512:
6
- metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
7
- data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
6
+ metadata.gz: 880fe0688bb5b15c016fdddb15b18f5e0b3ba2a45ae36292182adf8def20d93ca3ae176747dbf3d1369ea28cdaaf23e7cd9e96d0a0c6c4bb92db27131d8f4d93
7
+ data.tar.gz: 4d00dc6831f0c82e0dc1b4128d98dc391456410cbf26e0d0192b19092333664774eccabd2db0a609b739c1c40664e148dcd69152cdc45394690e21d785e1acc0
data/README.md CHANGED
@@ -42,6 +42,11 @@ model << Dense.new(10)
42
42
  model.setup(Adam.new, SoftmaxCrossEntropy.new)
43
43
 
44
44
  model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
45
+
46
+
47
+ accuracy, loss = model.evaluate(x_test, y_test)
48
+ puts "accuracy: #{accuracy}"
49
+ puts "loss: #{loss}"
45
50
  ```
46
51
 
47
52
  When create a model with 'define by run' style:
@@ -71,6 +76,10 @@ model = MLP.new
71
76
  model.setup(Adam.new, SoftmaxCrossEntropy.new)
72
77
 
73
78
  model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
79
+
80
+ accuracy, loss = model.evaluate(x_test, y_test)
81
+ puts "accuracy: #{accuracy}"
82
+ puts "loss: #{loss}"
74
83
  ```
75
84
 
76
85
  Please refer to examples for basic usage.
@@ -86,6 +95,19 @@ If you want to know more detailed information, please refer to the source code.
86
95
  | Optimizers | SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, RMSPropGraves, Adam, AdaBound |
87
96
  | Losses | MeanSquaredError, MeanAbsoluteError, Hinge, HuberLoss, SoftmaxCrossEntropy, SigmoidCrossEntropy |
88
97
 
98
+ ## Datasets
99
+ ● Iris
100
+ ● MNIST
101
+ ● Fashion-MNIST
102
+ ● CIFAR-10
103
+ ● CIFAR-100
104
+ ● STL-10
105
+
106
+ ## Examples
107
+ ● VAE
108
+ ● DCGAN
109
+ ● Pix2pix
110
+
89
111
  ## TODO
90
112
  ● Write a test.
91
113
  ● Write a document.
@@ -35,7 +35,7 @@ class MLP < Model
35
35
  end
36
36
 
37
37
  def forward(x)
38
- x = InputLayer.(x)
38
+ x = InputLayer.new(784).(x)
39
39
  x = @l1.(x)
40
40
  x = @bn1.(x)
41
41
  x = ReLU.(x)
@@ -36,7 +36,7 @@ class MLP < Model
36
36
  end
37
37
 
38
38
  def forward(x)
39
- x = InputLayer.(x)
39
+ x = InputLayer.new(784).(x)
40
40
  x = @l1.(x)
41
41
  x = @bn1.(x)
42
42
  x = ReLU.(x)
@@ -37,7 +37,7 @@ class MLP < Model
37
37
  end
38
38
 
39
39
  def forward(x)
40
- x = InputLayer.(x)
40
+ x = InputLayer.new(784).(x)
41
41
  x = @l1.(x)
42
42
  x = @bn1.(x)
43
43
  x = ReLU.(x)
@@ -61,6 +61,9 @@ class Discriminator < Model
61
61
  @l4 = Conv2D.new(64, 4, padding: true)
62
62
  @l5 = Dense.new(1024)
63
63
  @l6 = Dense.new(1)
64
+ @bn1 = BatchNormalization.new
65
+ @bn2 = BatchNormalization.new
66
+ @bn3 = BatchNormalization.new
64
67
  end
65
68
 
66
69
  def forward(x)
@@ -69,12 +72,15 @@ class Discriminator < Model
69
72
  x = LeakyReLU.(x, 0.2)
70
73
 
71
74
  x = @l2.(x)
75
+ x = @bn1.(x)
72
76
  x = LeakyReLU.(x, 0.2)
73
77
 
74
78
  x = @l3.(x)
79
+ x = @bn2.(x)
75
80
  x = LeakyReLU.(x, 0.2)
76
81
 
77
82
  x = @l4.(x)
83
+ x = @bn3.(x)
78
84
  x = LeakyReLU.(x, 0.2)
79
85
 
80
86
  x = Flatten.(x)
@@ -119,10 +125,11 @@ class DCGAN < Model
119
125
  batch_size = x_batch.shape[0]
120
126
  noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
121
127
  images = @gen.predict(noise)
122
- x = x_batch.concatenate(images)
123
- y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
128
+ y_real = Numo::SFloat.ones(batch_size, 1)
129
+ y_fake = Numo::SFloat.zeros(batch_size, 1)
124
130
  @dis.enable_training
125
- dis_loss = @dis.train_on_batch(x, y)
131
+ dis_loss = @dis.train_on_batch(x_batch, y_real)
132
+ dis_loss + @dis.train_on_batch(images, y_fake)
126
133
 
127
134
  noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
128
135
  label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
@@ -2,6 +2,8 @@ include DNN::Models
2
2
  include DNN::Layers
3
3
 
4
4
  class Generator < Model
5
+ attr_reader :generate_images
6
+
5
7
  def initialize(input_shape)
6
8
  super()
7
9
  @input_shape = input_shape
@@ -25,6 +27,7 @@ class Generator < Model
25
27
  @bn7 = BatchNormalization.new
26
28
  @bn8 = BatchNormalization.new
27
29
  @bn9 = BatchNormalization.new
30
+ @generate_images = nil
28
31
  end
29
32
 
30
33
  def forward(x)
@@ -72,6 +75,7 @@ class Generator < Model
72
75
 
73
76
  x = @l11.(x)
74
77
  x = Tanh.(x)
78
+ @generate_images = x.data
75
79
  x
76
80
  end
77
81
  end
@@ -24,6 +24,7 @@ gen = Generator.new([32, 32, 1])
24
24
  dis = Discriminator.new([32, 32, 1], [32, 32, 3])
25
25
  dcgan = DCGAN.new(gen, dis)
26
26
 
27
+ gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new)
27
28
  dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
28
29
  dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
29
30
 
@@ -35,7 +36,9 @@ num_batchs = x_in.shape[0] / batch_size
35
36
  (1..epochs).each do |epoch|
36
37
  num_batchs.times do |index|
37
38
  x_in, x_out = iter1.next_batch(batch_size)
38
- images = gen.predict(x_in)
39
+ gen_loss = gen.train_on_batch(x_in, x_out)
40
+
41
+ images = gen.generate_images
39
42
  y_real = Numo::SFloat.ones(batch_size, 1)
40
43
  y_fake = Numo::SFloat.zeros(batch_size, 1)
41
44
  dis.enable_training
@@ -45,7 +48,7 @@ num_batchs = x_in.shape[0] / batch_size
45
48
  x_in, x_out = iter2.next_batch(batch_size)
46
49
  dcgan_loss = dcgan.train_on_batch(x_in, y_real)
47
50
 
48
- puts "epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
51
+ puts "epoch: #{epoch}, index: #{index}, gen_loss: #{gen_loss}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
49
52
  end
50
53
  iter1.reset
51
54
  iter2.reset
@@ -9,16 +9,10 @@ include DNN::Optimizers
9
9
  include DNN::Losses
10
10
 
11
11
  x_train, y_train = DNN::MNIST.load_train
12
- x_test, y_test = DNN::MNIST.load_test
13
12
 
14
13
  x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
15
- x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
16
14
 
17
15
  x_train /= 255
18
- x_test /= 255
19
-
20
- y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
21
- y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
22
16
 
23
17
  $z_dim = 2
24
18
  $z_mean = nil
@@ -27,10 +27,11 @@ module DNN
27
27
 
28
28
  # This callback wrap the lambda function.
29
29
  class LambdaCallback < Callback
30
- def initialize(event, lambda = nil, &block)
31
- lambda = block unless lambda
30
+ # @param [Symbol] event Event to execute callback.
31
+ # @yield Register the contents of the callback.
32
+ def initialize(event, &block)
32
33
  instance_eval do
33
- define_singleton_method(event) { lambda.call }
34
+ define_singleton_method(event) { block.call }
34
35
  end
35
36
  end
36
37
  end
@@ -55,6 +56,9 @@ module DNN
55
56
  end
56
57
 
57
58
  # A callback to stop training the model early after test on batch.
59
+ # @param [Symbol] trigger A log that triggers early stopping.
60
+ # Specify one of train_loss, test_loss, test_accuracy.
61
+ # @param [Float] tolerance Tolerance value for early stopping.
58
62
  class EarlyStopping < Callback
59
63
  def initialize(trigger, tolerance)
60
64
  @trigger = trigger
@@ -1,5 +1,5 @@
1
1
  module DNN
2
- class DNN_Error < StandardError; end
2
+ class DNNError < StandardError; end
3
3
 
4
- class DNN_ShapeError < DNN_Error; end
4
+ class DNNShapeError < DNNError; end
5
5
  end
@@ -6,7 +6,7 @@ module DNN
6
6
  return nil unless hash
7
7
  initializer_class = DNN.const_get(hash[:class])
8
8
  initializer = initializer_class.allocate
9
- raise DNN_Error, "#{initializer.class} is not an instance of #{self} class." unless initializer.is_a?(self)
9
+ raise DNNError, "#{initializer.class} is not an instance of #{self} class." unless initializer.is_a?(self)
10
10
  initializer.load_hash(hash)
11
11
  initializer
12
12
  end
@@ -122,8 +122,8 @@ module DNN
122
122
 
123
123
  def init_param(layer, param)
124
124
  Xumo::SFloat.srand(@seed)
125
- num_prev_nodes = layer.input_shape.reduce(:*)
126
- param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes)
125
+ num_prev_units = layer.input_shape.reduce(:*)
126
+ param.data = param.data.rand_norm / Math.sqrt(num_prev_units)
127
127
  end
128
128
  end
129
129
 
@@ -134,8 +134,8 @@ module DNN
134
134
 
135
135
  def init_param(layer, param)
136
136
  Xumo::SFloat.srand(@seed)
137
- num_prev_nodes = layer.input_shape.reduce(:*)
138
- param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes) * Math.sqrt(2)
137
+ num_prev_units = layer.input_shape.reduce(:*)
138
+ param.data = param.data.rand_norm / Math.sqrt(num_prev_units) * Math.sqrt(2)
139
139
  end
140
140
  end
141
141
 
@@ -21,7 +21,7 @@ module DNN
21
21
  # @param [Integer] batch_size Required batch size.
22
22
  # @return [Array] Returns the mini batch in the form [x_batch, y_batch].
23
23
  def next_batch(batch_size)
24
- raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next?
24
+ raise DNNError, "This iterator has not next batch. Please call reset." unless has_next?
25
25
  if @indexes.length <= batch_size
26
26
  batch_indexes = @indexes
27
27
  @has_next = false
@@ -60,6 +60,9 @@ module DNN
60
60
  @has_next
61
61
  end
62
62
 
63
+ # Run a loop with all data separated by batch
64
+ # @param [Integer] batch_size Batch size.
65
+ # @yield Executes block by receiving the specified arguments (x_batch, y_batch).
63
66
  def foreach(batch_size, &block)
64
67
  steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
65
68
  steps.times do |step|
@@ -2,18 +2,14 @@ module DNN
2
2
  module Layers
3
3
 
4
4
  module LayerNode
5
- def forward(input_tensor)
6
- x = input_tensor.data
7
- prev_link = (input_tensor.is_a?(Tensor) ? input_tensor.link : input_tensor)
5
+ def forward(input)
6
+ x = input.data
7
+ prev = (input.is_a?(Tensor) ? input.link : input)
8
8
  y = forward_node(x)
9
- link = Link.new(prev_link, self)
9
+ link = Link.new(prev, self)
10
10
  Tensor.new(y, link)
11
11
  end
12
12
 
13
- def backward(dy)
14
- backward_node(dy)
15
- end
16
-
17
13
  def forward_node(x)
18
14
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward_node'"
19
15
  end
@@ -26,6 +22,7 @@ module DNN
26
22
  # Super class of all layer classes.
27
23
  class Layer
28
24
  attr_reader :input_shape
25
+ attr_reader :output_shape
29
26
 
30
27
  def self.call(x, *args)
31
28
  new(*args).(x)
@@ -35,7 +32,7 @@ module DNN
35
32
  return nil unless hash
36
33
  layer_class = DNN.const_get(hash[:class])
37
34
  layer = layer_class.allocate
38
- raise DNN_Error, "#{layer.class} is not an instance of #{self} class." unless layer.is_a?(self)
35
+ raise DNNError, "#{layer.class} is not an instance of #{self} class." unless layer.is_a?(self)
39
36
  layer.load_hash(hash)
40
37
  layer
41
38
  end
@@ -45,18 +42,19 @@ module DNN
45
42
  end
46
43
 
47
44
  # Forward propagation and create a link.
48
- # @param [Tensor] input_tensor Input tensor.
45
+ # @param [Tensor | Param] input Input tensor or param.
49
46
  # @return [Tensor] Output tensor.
50
- def call(input_tensor)
51
- input_tensor = Tensor.new(input_tensor) if !input_tensor.is_a?(Tensor) && !input_tensor.is_a?(Param)
52
- build(input_tensor.data.shape[1..-1]) unless built?
53
- forward(input_tensor)
47
+ def call(input)
48
+ input = Tensor.new(input) if !input.is_a?(Tensor) && !input.is_a?(Param)
49
+ build(input.data.shape[1..-1]) unless built?
50
+ forward(input)
54
51
  end
55
52
 
56
53
  # Build the layer.
57
54
  # @param [Array] input_shape Setting the shape of the input data.
58
55
  def build(input_shape)
59
56
  @input_shape = input_shape
57
+ @output_shape = compute_output_shape
60
58
  @built = true
61
59
  end
62
60
 
@@ -66,16 +64,16 @@ module DNN
66
64
  end
67
65
 
68
66
  # Forward propagation.
69
- # @param [Tensor] input_tensor Input tensor.
67
+ # @param [Tensor] input Input tensor or param.
70
68
  # @return [Tensor] Output tensor.
71
- def forward(input_tensor)
69
+ def forward(input)
72
70
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward'"
73
71
  end
74
72
 
75
73
  # Please reimplement this method as needed.
76
74
  # The default implementation return input_shape.
77
75
  # @return [Array] Return the shape of the output data.
78
- def output_shape
76
+ def compute_output_shape
79
77
  @input_shape
80
78
  end
81
79
 
@@ -135,60 +133,37 @@ module DNN
135
133
  end
136
134
 
137
135
  class InputLayer < Layer
138
- include LayerNode
139
-
140
- def self.call(input)
141
- shape = input.is_a?(Tensor) ? input.data.shape : input.shape
142
- new(shape[1..-1]).(input)
143
- end
144
-
145
136
  # @param [Array] input_dim_or_shape Setting the shape or dimension of the input data.
146
137
  def initialize(input_dim_or_shape)
147
138
  super()
148
139
  @input_shape = input_dim_or_shape.is_a?(Array) ? input_dim_or_shape : [input_dim_or_shape]
149
140
  end
150
141
 
151
- def call(input)
152
- build(@input_shape) unless built?
153
- if input.is_a?(Tensor)
154
- x = input.data
155
- prev_link = input&.link
156
- else
157
- x = input
158
- prev_link = nil
159
- end
160
- Tensor.new(forward_node(x), Link.new(prev_link, self))
161
- end
162
-
163
142
  def build(input_shape)
164
- @built = true
143
+ super(@input_shape)
165
144
  end
166
145
 
167
- def forward_node(x)
146
+ def forward(x)
168
147
  unless x.shape[1..-1] == @input_shape
169
- raise DNN_ShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
148
+ raise DNNShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
170
149
  end
171
150
  x
172
151
  end
173
152
 
174
- def backward_node(dy)
175
- dy
176
- end
177
-
178
153
  def to_proc
179
154
  method(:call).to_proc
180
155
  end
181
156
 
182
157
  def >>(layer)
183
158
  if RUBY_VERSION < "2.6.0"
184
- raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
159
+ raise DNNError, "Function composition is not supported before ruby version 2.6.0."
185
160
  end
186
161
  to_proc >> layer
187
162
  end
188
163
 
189
164
  def <<(layer)
190
165
  if RUBY_VERSION < "2.6.0"
191
- raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
166
+ raise DNNError, "Function composition is not supported before ruby version 2.6.0."
192
167
  end
193
168
  to_proc << layer
194
169
  end
@@ -267,10 +242,10 @@ module DNN
267
242
  class Dense < Connection
268
243
  include LayerNode
269
244
 
270
- attr_reader :num_nodes
245
+ attr_reader :num_units
271
246
 
272
- # @param [Integer] num_nodes Number of nodes.
273
- def initialize(num_nodes,
247
+ # @param [Integer] num_units Number of nodes.
248
+ def initialize(num_units,
274
249
  weight_initializer: Initializers::RandomNormal.new,
275
250
  bias_initializer: Initializers::Zeros.new,
276
251
  weight_regularizer: nil,
@@ -278,17 +253,17 @@ module DNN
278
253
  use_bias: true)
279
254
  super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
280
255
  weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
281
- @num_nodes = num_nodes
256
+ @num_units = num_units
282
257
  end
283
258
 
284
259
  def build(input_shape)
285
260
  unless input_shape.length == 1
286
- raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 1 dimensional."
261
+ raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 1 dimensional."
287
262
  end
288
263
  super
289
- num_prev_nodes = input_shape[0]
290
- @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
291
- @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
264
+ num_prev_units = input_shape[0]
265
+ @weight.data = Xumo::SFloat.new(num_prev_units, @num_units)
266
+ @bias.data = Xumo::SFloat.new(@num_units) if @bias
292
267
  init_weight_and_bias
293
268
  end
294
269
 
@@ -307,16 +282,16 @@ module DNN
307
282
  dy.dot(@weight.data.transpose)
308
283
  end
309
284
 
310
- def output_shape
311
- [@num_nodes]
285
+ def compute_output_shape
286
+ [@num_units]
312
287
  end
313
288
 
314
289
  def to_hash
315
- super(num_nodes: @num_nodes)
290
+ super(num_units: @num_units)
316
291
  end
317
292
 
318
293
  def load_hash(hash)
319
- initialize(hash[:num_nodes],
294
+ initialize(hash[:num_units],
320
295
  weight_initializer: Initializers::Initializer.from_hash(hash[:weight_initializer]),
321
296
  bias_initializer: Initializers::Initializer.from_hash(hash[:bias_initializer]),
322
297
  weight_regularizer: Regularizers::Regularizer.from_hash(hash[:weight_regularizer]),
@@ -329,14 +304,14 @@ module DNN
329
304
  include LayerNode
330
305
 
331
306
  def forward_node(x)
332
- x.reshape(x.shape[0], *output_shape)
307
+ x.reshape(x.shape[0], *@output_shape)
333
308
  end
334
309
 
335
310
  def backward_node(dy)
336
311
  dy.reshape(dy.shape[0], *@input_shape)
337
312
  end
338
313
 
339
- def output_shape
314
+ def compute_output_shape
340
315
  [@input_shape.reduce(:*)]
341
316
  end
342
317
  end
@@ -344,11 +319,13 @@ module DNN
344
319
  class Reshape < Layer
345
320
  include LayerNode
346
321
 
347
- attr_reader :output_shape
348
-
349
- def initialize(output_shape)
322
+ def initialize(shape)
350
323
  super()
351
- @output_shape = output_shape
324
+ @shape = shape
325
+ end
326
+
327
+ def compute_output_shape
328
+ @shape
352
329
  end
353
330
 
354
331
  def forward_node(x)
@@ -360,11 +337,11 @@ module DNN
360
337
  end
361
338
 
362
339
  def to_hash
363
- super(output_shape: @output_shape)
340
+ super(shape: @shape)
364
341
  end
365
342
 
366
343
  def load_hash(hash)
367
- initialize(hash[:output_shape])
344
+ initialize(hash[:shape])
368
345
  end
369
346
  end
370
347