ruby-dnn 0.16.2 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. checksums.yaml +4 -4
  2. data/README.md +22 -0
  3. data/examples/api-examples/early_stopping_example.rb +1 -1
  4. data/examples/api-examples/initializer_example.rb +1 -1
  5. data/examples/api-examples/regularizer_example.rb +1 -1
  6. data/examples/dcgan/dcgan.rb +10 -3
  7. data/examples/pix2pix/dcgan.rb +4 -0
  8. data/examples/pix2pix/train.rb +5 -2
  9. data/examples/vae.rb +0 -6
  10. data/lib/dnn/core/callbacks.rb +7 -3
  11. data/lib/dnn/core/error.rb +2 -2
  12. data/lib/dnn/core/initializers.rb +5 -5
  13. data/lib/dnn/core/iterator.rb +4 -1
  14. data/lib/dnn/core/layers/basic_layers.rb +42 -65
  15. data/lib/dnn/core/layers/cnn_layers.rb +34 -35
  16. data/lib/dnn/core/layers/embedding.rb +3 -24
  17. data/lib/dnn/core/layers/math_layers.rb +12 -0
  18. data/lib/dnn/core/layers/merge_layers.rb +13 -13
  19. data/lib/dnn/core/layers/normalizations.rb +4 -4
  20. data/lib/dnn/core/layers/rnn_layers.rb +46 -46
  21. data/lib/dnn/core/link.rb +8 -8
  22. data/lib/dnn/core/losses.rb +10 -20
  23. data/lib/dnn/core/models.rb +23 -46
  24. data/lib/dnn/core/monkey_patch.rb +10 -0
  25. data/lib/dnn/core/optimizers.rb +1 -2
  26. data/lib/dnn/core/param.rb +2 -2
  27. data/lib/dnn/core/regularizers.rb +1 -1
  28. data/lib/dnn/core/savers.rb +2 -2
  29. data/lib/dnn/core/tensor.rb +1 -1
  30. data/lib/dnn/datasets/cifar10.rb +1 -1
  31. data/lib/dnn/datasets/cifar100.rb +1 -1
  32. data/lib/dnn/datasets/downloader.rb +1 -1
  33. data/lib/dnn/datasets/fashion-mnist.rb +1 -1
  34. data/lib/dnn/datasets/iris.rb +1 -1
  35. data/lib/dnn/datasets/mnist.rb +1 -1
  36. data/lib/dnn/datasets/stl-10.rb +2 -2
  37. data/lib/dnn/version.rb +1 -1
  38. metadata +2 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
-  data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
+  metadata.gz: 0db9ac3047ba8c15d903ace901f5e4e332835d11dffca2f441664ae843049d1d
+  data.tar.gz: f1b4bf61da8a48b8ad483eb806ab443bb40f1b0d88573c2d901ae45299abf86d
 SHA512:
-  metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
-  data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
+  metadata.gz: 880fe0688bb5b15c016fdddb15b18f5e0b3ba2a45ae36292182adf8def20d93ca3ae176747dbf3d1369ea28cdaaf23e7cd9e96d0a0c6c4bb92db27131d8f4d93
+  data.tar.gz: 4d00dc6831f0c82e0dc1b4128d98dc391456410cbf26e0d0192b19092333664774eccabd2db0a609b739c1c40664e148dcd69152cdc45394690e21d785e1acc0
data/README.md CHANGED
@@ -42,6 +42,11 @@ model << Dense.new(10)
 model.setup(Adam.new, SoftmaxCrossEntropy.new)

 model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
+
+
+accuracy, loss = model.evaluate(x_test, y_test)
+puts "accuracy: #{accuracy}"
+puts "loss: #{loss}"
 ```

 When create a model with 'define by run' style:
@@ -71,6 +76,10 @@ model = MLP.new
 model.setup(Adam.new, SoftmaxCrossEntropy.new)

 model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
+
+accuracy, loss = model.evaluate(x_test, y_test)
+puts "accuracy: #{accuracy}"
+puts "loss: #{loss}"
 ```

 Please refer to examples for basic usage.
@@ -86,6 +95,19 @@ If you want to know more detailed information, please refer to the source code.
 | Optimizers | SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, RMSPropGraves, Adam, AdaBound |
 | Losses | MeanSquaredError, MeanAbsoluteError, Hinge, HuberLoss, SoftmaxCrossEntropy, SigmoidCrossEntropy |

+## Datasets
+● Iris
+● MNIST
+● Fashion-MNIST
+● CIFAR-10
+● CIFAR-100
+● STL-10
+
+## Examples
+● VAE
+● DCGAN
+● Pix2pix
+
 ## TODO
 ● Write a test.
 ● Write a document.
@@ -35,7 +35,7 @@ class MLP < Model
   end

   def forward(x)
-    x = InputLayer.(x)
+    x = InputLayer.new(784).(x)
     x = @l1.(x)
     x = @bn1.(x)
     x = ReLU.(x)
@@ -36,7 +36,7 @@ class MLP < Model
   end

   def forward(x)
-    x = InputLayer.(x)
+    x = InputLayer.new(784).(x)
     x = @l1.(x)
     x = @bn1.(x)
     x = ReLU.(x)
@@ -37,7 +37,7 @@ class MLP < Model
   end

   def forward(x)
-    x = InputLayer.(x)
+    x = InputLayer.new(784).(x)
     x = @l1.(x)
     x = @bn1.(x)
     x = ReLU.(x)
@@ -61,6 +61,9 @@ class Discriminator < Model
     @l4 = Conv2D.new(64, 4, padding: true)
     @l5 = Dense.new(1024)
     @l6 = Dense.new(1)
+    @bn1 = BatchNormalization.new
+    @bn2 = BatchNormalization.new
+    @bn3 = BatchNormalization.new
   end

   def forward(x)
@@ -69,12 +72,15 @@ class Discriminator < Model
     x = LeakyReLU.(x, 0.2)

     x = @l2.(x)
+    x = @bn1.(x)
     x = LeakyReLU.(x, 0.2)

     x = @l3.(x)
+    x = @bn2.(x)
     x = LeakyReLU.(x, 0.2)

     x = @l4.(x)
+    x = @bn3.(x)
     x = LeakyReLU.(x, 0.2)

     x = Flatten.(x)
@@ -119,10 +125,11 @@ class DCGAN < Model
     batch_size = x_batch.shape[0]
     noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
     images = @gen.predict(noise)
-    x = x_batch.concatenate(images)
-    y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
+    y_real = Numo::SFloat.ones(batch_size, 1)
+    y_fake = Numo::SFloat.zeros(batch_size, 1)
     @dis.enable_training
-    dis_loss = @dis.train_on_batch(x, y)
+    dis_loss = @dis.train_on_batch(x_batch, y_real)
+    dis_loss + @dis.train_on_batch(images, y_fake)

     noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
     label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
@@ -2,6 +2,8 @@ include DNN::Models
 include DNN::Layers

 class Generator < Model
+  attr_reader :generate_images
+
   def initialize(input_shape)
     super()
     @input_shape = input_shape
@@ -25,6 +27,7 @@ class Generator < Model
     @bn7 = BatchNormalization.new
     @bn8 = BatchNormalization.new
     @bn9 = BatchNormalization.new
+    @generate_images = nil
   end

   def forward(x)
@@ -72,6 +75,7 @@ class Generator < Model

     x = @l11.(x)
     x = Tanh.(x)
+    @generate_images = x.data
     x
   end
 end
@@ -24,6 +24,7 @@ gen = Generator.new([32, 32, 1])
 dis = Discriminator.new([32, 32, 1], [32, 32, 3])
 dcgan = DCGAN.new(gen, dis)

+gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new)
 dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
 dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)

@@ -35,7 +36,9 @@ num_batchs = x_in.shape[0] / batch_size
 (1..epochs).each do |epoch|
   num_batchs.times do |index|
     x_in, x_out = iter1.next_batch(batch_size)
-    images = gen.predict(x_in)
+    gen_loss = gen.train_on_batch(x_in, x_out)
+
+    images = gen.generate_images
     y_real = Numo::SFloat.ones(batch_size, 1)
     y_fake = Numo::SFloat.zeros(batch_size, 1)
     dis.enable_training
@@ -45,7 +48,7 @@ num_batchs = x_in.shape[0] / batch_size
     x_in, x_out = iter2.next_batch(batch_size)
     dcgan_loss = dcgan.train_on_batch(x_in, y_real)

-    puts "epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
+    puts "epoch: #{epoch}, index: #{index}, gen_loss: #{gen_loss}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
   end
   iter1.reset
   iter2.reset
@@ -9,16 +9,10 @@ include DNN::Optimizers
 include DNN::Losses

 x_train, y_train = DNN::MNIST.load_train
-x_test, y_test = DNN::MNIST.load_test

 x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
-x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)

 x_train /= 255
-x_test /= 255
-
-y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
-y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)

 $z_dim = 2
 $z_mean = nil
@@ -27,10 +27,11 @@ module DNN

   # This callback wrap the lambda function.
   class LambdaCallback < Callback
-    def initialize(event, lambda = nil, &block)
-      lambda = block unless lambda
+    # @param [Symbol] event Event to execute callback.
+    # @yield Register the contents of the callback.
+    def initialize(event, &block)
       instance_eval do
-        define_singleton_method(event) { lambda.call }
+        define_singleton_method(event) { block.call }
       end
     end
   end
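The lambda argument is gone from LambdaCallback, so the callback body is now always supplied as a block. A minimal, hypothetical usage sketch; the DNN::Callbacks namespace, the :after_epoch event name, and Model#add_callback are assumptions based on this diff rather than verbatim documentation:

```ruby
# Hypothetical sketch of the block-only LambdaCallback API in 1.0.0.
callback = DNN::Callbacks::LambdaCallback.new(:after_epoch) do
  puts "epoch finished"
end
model.add_callback(callback)  # model is assumed to be an already set-up Model
```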
@@ -55,6 +56,9 @@ module DNN
   end

   # A callback to stop training the model early after test on batch.
+  # @param [Symbol] trigger A log that triggers early stopping.
+  #   Specify one of train_loss, test_loss, test_accuracy.
+  # @param [Float] tolerance Tolerance value for early stopping.
   class EarlyStopping < Callback
     def initialize(trigger, tolerance)
       @trigger = trigger
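The new doc comments spell out the EarlyStopping parameters: a trigger log (one of train_loss, test_loss, test_accuracy) and a tolerance value. A hedged sketch of wiring it up; the symbol form of the trigger and Model#add_callback are assumptions:

```ruby
# Hypothetical sketch: stop training once the monitored test loss reaches 0.1.
model.add_callback(DNN::Callbacks::EarlyStopping.new(:test_loss, 0.1))
```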
@@ -1,5 +1,5 @@
 module DNN
-  class DNN_Error < StandardError; end
+  class DNNError < StandardError; end

-  class DNN_ShapeError < DNN_Error; end
+  class DNNShapeError < DNNError; end
 end
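Because the exception classes were renamed, rescue clauses written against 0.16.x need the new constants. A minimal sketch, assuming only the renames shown above:

```ruby
# Hypothetical sketch: rescuing the renamed 1.0.0 exception classes.
begin
  model.predict(x)               # x stands in for input data of the wrong shape
rescue DNN::DNNShapeError => e   # was DNN::DNN_ShapeError in 0.16.x
  warn "shape mismatch: #{e.message}"
rescue DNN::DNNError => e        # was DNN::DNN_Error in 0.16.x
  warn "dnn error: #{e.message}"
end
```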
@@ -6,7 +6,7 @@ module DNN
       return nil unless hash
       initializer_class = DNN.const_get(hash[:class])
       initializer = initializer_class.allocate
-      raise DNN_Error, "#{initializer.class} is not an instance of #{self} class." unless initializer.is_a?(self)
+      raise DNNError, "#{initializer.class} is not an instance of #{self} class." unless initializer.is_a?(self)
       initializer.load_hash(hash)
       initializer
     end
@@ -122,8 +122,8 @@ module DNN

     def init_param(layer, param)
       Xumo::SFloat.srand(@seed)
-      num_prev_nodes = layer.input_shape.reduce(:*)
-      param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes)
+      num_prev_units = layer.input_shape.reduce(:*)
+      param.data = param.data.rand_norm / Math.sqrt(num_prev_units)
     end
   end

@@ -134,8 +134,8 @@ module DNN

     def init_param(layer, param)
       Xumo::SFloat.srand(@seed)
-      num_prev_nodes = layer.input_shape.reduce(:*)
-      param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes) * Math.sqrt(2)
+      num_prev_units = layer.input_shape.reduce(:*)
+      param.data = param.data.rand_norm / Math.sqrt(num_prev_units) * Math.sqrt(2)
     end
   end

@@ -21,7 +21,7 @@ module DNN
     # @param [Integer] batch_size Required batch size.
     # @return [Array] Returns the mini batch in the form [x_batch, y_batch].
     def next_batch(batch_size)
-      raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next?
+      raise DNNError, "This iterator has not next batch. Please call reset." unless has_next?
       if @indexes.length <= batch_size
         batch_indexes = @indexes
         @has_next = false
@@ -60,6 +60,9 @@ module DNN
       @has_next
     end

+    # Run a loop with all data separated by batch
+    # @param [Integer] batch_size Batch size.
+    # @yield Executes block by receiving the specified arguments (x_batch, y_batch).
     def foreach(batch_size, &block)
       steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
       steps.times do |step|
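The new doc comments describe Iterator#foreach as looping over all data in batches and yielding (x_batch, y_batch). A small usage sketch; constructing the iterator from paired x/y arrays mirrors the bundled examples and is an assumption here:

```ruby
# Hypothetical sketch of Iterator#foreach based on the doc comments above.
iter = DNN::Iterator.new(x_train, y_train)
iter.foreach(128) do |x_batch, y_batch|
  loss = model.train_on_batch(x_batch, y_batch)
  puts "batch loss: #{loss}"
end
```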
@@ -2,18 +2,14 @@ module DNN
   module Layers

     module LayerNode
-      def forward(input_tensor)
-        x = input_tensor.data
-        prev_link = (input_tensor.is_a?(Tensor) ? input_tensor.link : input_tensor)
+      def forward(input)
+        x = input.data
+        prev = (input.is_a?(Tensor) ? input.link : input)
         y = forward_node(x)
-        link = Link.new(prev_link, self)
+        link = Link.new(prev, self)
         Tensor.new(y, link)
       end

-      def backward(dy)
-        backward_node(dy)
-      end
-
       def forward_node(x)
         raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward_node'"
       end
@@ -26,6 +22,7 @@ module DNN
     # Super class of all layer classes.
     class Layer
       attr_reader :input_shape
+      attr_reader :output_shape

       def self.call(x, *args)
         new(*args).(x)
@@ -35,7 +32,7 @@ module DNN
         return nil unless hash
         layer_class = DNN.const_get(hash[:class])
         layer = layer_class.allocate
-        raise DNN_Error, "#{layer.class} is not an instance of #{self} class." unless layer.is_a?(self)
+        raise DNNError, "#{layer.class} is not an instance of #{self} class." unless layer.is_a?(self)
         layer.load_hash(hash)
         layer
       end
@@ -45,18 +42,19 @@ module DNN
       end

       # Forward propagation and create a link.
-      # @param [Tensor] input_tensor Input tensor.
+      # @param [Tensor | Param] input Input tensor or param.
       # @return [Tensor] Output tensor.
-      def call(input_tensor)
-        input_tensor = Tensor.new(input_tensor) if !input_tensor.is_a?(Tensor) && !input_tensor.is_a?(Param)
-        build(input_tensor.data.shape[1..-1]) unless built?
-        forward(input_tensor)
+      def call(input)
+        input = Tensor.new(input) if !input.is_a?(Tensor) && !input.is_a?(Param)
+        build(input.data.shape[1..-1]) unless built?
+        forward(input)
       end

       # Build the layer.
       # @param [Array] input_shape Setting the shape of the input data.
       def build(input_shape)
         @input_shape = input_shape
+        @output_shape = compute_output_shape
         @built = true
       end

@@ -66,16 +64,16 @@ module DNN
       end

       # Forward propagation.
-      # @param [Tensor] input_tensor Input tensor.
+      # @param [Tensor] input Input tensor or param.
       # @return [Tensor] Output tensor.
-      def forward(input_tensor)
+      def forward(input)
         raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward'"
       end

       # Please reimplement this method as needed.
       # The default implementation return input_shape.
       # @return [Array] Return the shape of the output data.
-      def output_shape
+      def compute_output_shape
         @input_shape
       end
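A layer's output shape is now computed once in build (via compute_output_shape) and cached behind the output_shape reader, so custom layers override compute_output_shape instead of output_shape. A minimal sketch of a hypothetical custom layer under that convention:

```ruby
# Hypothetical custom layer against the 1.0.0 Layer API shown above:
# the overridable hook is compute_output_shape, and build caches its result.
class DropLastUnit < DNN::Layers::Layer
  # forward is omitted in this sketch; a real layer must still implement it.
  def compute_output_shape
    [@input_shape[0] - 1]
  end
end

layer = DropLastUnit.new
layer.build([10])
p layer.output_shape  # => [9]
```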
@@ -135,60 +133,37 @@ module DNN
     end

     class InputLayer < Layer
-      include LayerNode
-
-      def self.call(input)
-        shape = input.is_a?(Tensor) ? input.data.shape : input.shape
-        new(shape[1..-1]).(input)
-      end
-
       # @param [Array] input_dim_or_shape Setting the shape or dimension of the input data.
       def initialize(input_dim_or_shape)
         super()
         @input_shape = input_dim_or_shape.is_a?(Array) ? input_dim_or_shape : [input_dim_or_shape]
       end

-      def call(input)
-        build(@input_shape) unless built?
-        if input.is_a?(Tensor)
-          x = input.data
-          prev_link = input&.link
-        else
-          x = input
-          prev_link = nil
-        end
-        Tensor.new(forward_node(x), Link.new(prev_link, self))
-      end
-
       def build(input_shape)
-        @built = true
+        super(@input_shape)
       end

-      def forward_node(x)
+      def forward(x)
         unless x.shape[1..-1] == @input_shape
-          raise DNN_ShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
+          raise DNNShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
         end
         x
       end

-      def backward_node(dy)
-        dy
-      end
-
       def to_proc
         method(:call).to_proc
       end

       def >>(layer)
         if RUBY_VERSION < "2.6.0"
-          raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
+          raise DNNError, "Function composition is not supported before ruby version 2.6.0."
         end
         to_proc >> layer
       end

       def <<(layer)
         if RUBY_VERSION < "2.6.0"
-          raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
+          raise DNNError, "Function composition is not supported before ruby version 2.6.0."
         end
         to_proc << layer
       end
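With the shape-inferring InputLayer.call class method removed, call sites that wrote InputLayer.(x) must construct the layer with an explicit dimension or shape, which is exactly how the example files earlier in this diff were updated. A minimal sketch of that pattern inside a 'define by run' forward method (the layer instance variables are placeholders):

```ruby
# Hypothetical sketch: in 1.0.0 InputLayer is given its input shape explicitly.
def forward(x)
  x = InputLayer.new(784).(x)   # was `x = InputLayer.(x)` in 0.16.x
  x = @l1.(x)
  x = ReLU.(x)
  @l2.(x)
end
```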
@@ -267,10 +242,10 @@ module DNN
     class Dense < Connection
       include LayerNode

-      attr_reader :num_nodes
+      attr_reader :num_units

-      # @param [Integer] num_nodes Number of nodes.
-      def initialize(num_nodes,
+      # @param [Integer] num_units Number of nodes.
+      def initialize(num_units,
                      weight_initializer: Initializers::RandomNormal.new,
                      bias_initializer: Initializers::Zeros.new,
                      weight_regularizer: nil,
@@ -278,17 +253,17 @@ module DNN
                      use_bias: true)
         super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
               weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
-        @num_nodes = num_nodes
+        @num_units = num_units
       end

       def build(input_shape)
         unless input_shape.length == 1
-          raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 1 dimensional."
+          raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 1 dimensional."
         end
         super
-        num_prev_nodes = input_shape[0]
-        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
-        @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
+        num_prev_units = input_shape[0]
+        @weight.data = Xumo::SFloat.new(num_prev_units, @num_units)
+        @bias.data = Xumo::SFloat.new(@num_units) if @bias
         init_weight_and_bias
       end

@@ -307,16 +282,16 @@ module DNN
         dy.dot(@weight.data.transpose)
       end

-      def output_shape
-        [@num_nodes]
+      def compute_output_shape
+        [@num_units]
       end

       def to_hash
-        super(num_nodes: @num_nodes)
+        super(num_units: @num_units)
       end

       def load_hash(hash)
-        initialize(hash[:num_nodes],
+        initialize(hash[:num_units],
                    weight_initializer: Initializers::Initializer.from_hash(hash[:weight_initializer]),
                    bias_initializer: Initializers::Initializer.from_hash(hash[:bias_initializer]),
                    weight_regularizer: Regularizers::Regularizer.from_hash(hash[:weight_regularizer]),
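Renaming num_nodes to num_units also changes the serialized key, so a hash written by 0.16.x stores :num_nodes while 1.0.0's load_hash reads :num_units. A minimal sketch of the renamed reader; calling build directly on a bare layer is only for illustration:

```ruby
# Hypothetical sketch of the renamed Dense attribute in 1.0.0.
dense = DNN::Layers::Dense.new(10)
dense.build([784])     # allocates a 784 x 10 weight matrix
p dense.num_units      # => 10 (was num_nodes in 0.16.x)
p dense.output_shape   # => [10], cached by build via compute_output_shape
```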
@@ -329,14 +304,14 @@
       include LayerNode

       def forward_node(x)
-        x.reshape(x.shape[0], *output_shape)
+        x.reshape(x.shape[0], *@output_shape)
       end

       def backward_node(dy)
         dy.reshape(dy.shape[0], *@input_shape)
       end

-      def output_shape
+      def compute_output_shape
         [@input_shape.reduce(:*)]
       end
     end
@@ -344,11 +319,13 @@ module DNN
     class Reshape < Layer
       include LayerNode

-      attr_reader :output_shape
-
-      def initialize(output_shape)
+      def initialize(shape)
         super()
-        @output_shape = output_shape
+        @shape = shape
+      end
+
+      def compute_output_shape
+        @shape
       end

       def forward_node(x)
@@ -360,11 +337,11 @@ module DNN
       end

       def to_hash
-        super(output_shape: @output_shape)
+        super(shape: @shape)
       end

       def load_hash(hash)
-        initialize(hash[:output_shape])
+        initialize(hash[:shape])
       end
     end
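Reshape's constructor argument and hash key changed from output_shape to shape. A minimal usage sketch; the target shape is arbitrary and the standalone build call is only for illustration:

```ruby
# Hypothetical sketch of the 1.0.0 Reshape API: the target shape is passed as
# `shape` and serialized under :shape (formerly :output_shape).
reshape = DNN::Layers::Reshape.new([7, 7, 16])
reshape.build([784])        # 7 * 7 * 16 == 784, so the element count matches
p reshape.output_shape      # => [7, 7, 16]
p reshape.to_hash[:shape]   # => [7, 7, 16]
```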