ruby-dnn 0.10.2 → 0.10.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 4587938452bbdaf1a51f7fb9d7c3693570194d349204bdbfb8aa17c7f3f0db73
4
- data.tar.gz: 8b893be6c2a546ddae5c39c8a5a5967864b1af0c320a11e8d7ed72f5a738d144
3
+ metadata.gz: ad7089d958268c000dbe71a9fd44b5a65a9824fe792ad445c0ed58eff33f82fe
4
+ data.tar.gz: 249b43df73430bf42c15e435be51842ec6b8d0a52fb46e59b573ac5fe14d44f6
5
5
  SHA512:
6
- metadata.gz: 591c74427f134032c4a9ee0ed01fa0ced4b238a918ae1001c3d98b86cc9d5693be94e33f91f069dff58da0dd156a38746eb6665667bd872081f42d41230e77eb
7
- data.tar.gz: 4ac43a25bcee0f02453764fefc2dbd8e812158cd0d1f78eeb4d3c1ad1f3047634a831b49caf55202f9e9824c73bd7d2f40b318c8f22e005c9702d6d8fe2ceebf
6
+ metadata.gz: 8d904d65210bae39bb4ce3c0d4cb7bd31331663084a112a8ba11424ea73b06d7706ead1c2113e9812e98dd5f82510fe045c4deb13c92480b2897811bbf063885
7
+ data.tar.gz: 7c32096caf2a00edab6d27a65f3dfdd1663291be10ac50af256fc9486b866fd9455d5d4216aebb71b984fc41c3f6db7272e27aa39accaa0f0489990d8c562d3f
@@ -123,6 +123,9 @@ module DNN
123
123
  end
124
124
 
125
125
  def build(input_shape)
126
+ unless input_shape.length == 3
127
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 3 dimensional.")
128
+ end
126
129
  super
127
130
  prev_h, prev_w, num_prev_filter = *input_shape
128
131
  @weight.data = Xumo::SFloat.new(@filter_size.reduce(:*) * num_prev_filter, @num_filters)
@@ -228,16 +231,14 @@ module DNN
228
231
  end
229
232
 
230
233
  def build(input_shape)
234
+ unless input_shape.length == 3
235
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 3 dimensional.")
236
+ end
231
237
  super
232
238
  prev_h, prev_w, num_prev_filter = *input_shape
233
239
  @weight.data = Xumo::SFloat.new(@filter_size.reduce(:*) * @num_filters, num_prev_filter)
234
- @weight_initializer.init_param(self, @weight)
235
- @weight_regularizer.param = @weight if @weight_regularizer
236
- if @bias
237
- @bias.data = Xumo::SFloat.new(@num_filters)
238
- @bias_initializer.init_param(self, @bias)
239
- @bias_regularizer.param = @bias if @bias_regularizer
240
- end
240
+ @bias.data = Xumo::SFloat.new(@num_filters) if @bias
241
+ init_weight_and_bias
241
242
  if @padding == true
242
243
  out_h, out_w = calc_deconv2d_out_size(prev_h, prev_w, *@filter_size, 0, 0, @strides)
243
244
  @pad_size = calc_padding_size(out_h, out_w, prev_h, prev_w, @strides)
@@ -327,6 +328,9 @@ module DNN
327
328
  end
328
329
 
329
330
  def build(input_shape)
331
+ unless input_shape.length == 3
332
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 3 dimensional.")
333
+ end
330
334
  super
331
335
  prev_h, prev_w = input_shape[0..1]
332
336
  @num_channel = input_shape[2]
@@ -423,6 +427,9 @@ module DNN
423
427
  end
424
428
 
425
429
  def build(input_shape)
430
+ unless input_shape.length == 3
431
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 3 dimensional.")
432
+ end
426
433
  super
427
434
  prev_h, prev_w = input_shape[0..1]
428
435
  unpool_h, unpool_w = @unpool_size
@@ -188,6 +188,9 @@ module DNN
188
188
  end
189
189
 
190
190
  def build(input_shape)
191
+ unless input_shape.length == 1
192
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 1 dimensional.")
193
+ end
191
194
  super
192
195
  num_prev_nodes = input_shape[0]
193
196
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
@@ -120,7 +120,6 @@ module DNN
120
120
  @optimizer = optimizer
121
121
  @loss_func = loss_func
122
122
  build
123
- layers_shape_check
124
123
  end
125
124
 
126
125
  # Set optimizer and loss_func to model and recompile. But does not build layers.
@@ -137,7 +136,6 @@ module DNN
137
136
  layers_check
138
137
  @optimizer = optimizer
139
138
  @loss_func = loss_func
140
- layers_shape_check
141
139
  end
142
140
 
143
141
  def build(super_model = nil)
@@ -428,26 +426,6 @@ module DNN
428
426
  end
429
427
  end
430
428
 
431
- def layers_shape_check
432
- @layers.each.with_index do |layer, i|
433
- prev_shape = layer.input_shape
434
- if layer.is_a?(Layers::Dense)
435
- if prev_shape.length != 1
436
- raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
437
- end
438
- elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
439
- if prev_shape.length != 3
440
- raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
441
- end
442
- elsif layer.is_a?(Layers::RNN)
443
- if prev_shape.length != 2
444
- layer_name = layer.class.name.match("\:\:(.+)$")[1]
445
- raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
446
- end
447
- end
448
- end
449
- end
450
-
451
429
  def check_xy_type(x, y = nil)
452
430
  unless x.is_a?(Xumo::SFloat)
453
431
  raise TypeError.new("x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class.")
@@ -39,6 +39,9 @@ module DNN
39
39
  end
40
40
 
41
41
  def build(input_shape)
42
+ unless input_shape.length == 2
43
+ raise DNN_ShapeError.new("Input shape is #{input_shape}. But input shape must be 2 dimensional.")
44
+ end
42
45
  super
43
46
  @time_length = @input_shape[0]
44
47
  end
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.10.2"
2
+ VERSION = "0.10.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.2
4
+ version: 0.10.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-06-23 00:00:00.000000000 Z
11
+ date: 2019-06-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray