ruby-dnn 0.7.1 → 0.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e35a9f970e0b936026f31c819d3a6763f6362706e14c74a98771895e809835cf
4
- data.tar.gz: ea8918fc438cfcda9dba644871c56d12659137a949ab74351b6bea6ba69e2318
3
+ metadata.gz: cf955bc9993a8c9495b144659cedc0cc97c1bde3bc1521be2453847d9d3afd91
4
+ data.tar.gz: b64395b3976619575ded55fd7ab508d09a8f0cde0dd0ab8a7d611dd7ae76f169
5
5
  SHA512:
6
- metadata.gz: 7fcf64d42d502e9cfef831ba8a4e8c3afc2bade2632c42fb20a180baec494fecf026e087bfbf758f32f331243bac217ee1ab05e178ae1c12a7172c0c0ef96681
7
- data.tar.gz: a5064e28bec910a3f344af5777ca0a2da2495fd82a8f6a7cf1f33c8422d3e09459a636d90862a7db875cf2cb1b498926b7b1c5959b3e24c0cc12394a5de0f3dc
6
+ metadata.gz: 8dca44e876a1403b8b02e94936ce46f98a7dfa14a47874976c88fad12544f447481959cec6d563c6af6cfe9fa643f37a18c016b39e1da955f915f18027580f11
7
+ data.tar.gz: ad647ff01b6f9d9afe9a7dde002e68dce80dc1de57570d160e684f2a54fc635c9a0f6661fc563adbbe8196bd517f7c86f546302b185bac557209773f8d832da0
@@ -52,7 +52,7 @@ module DNN
52
52
  self.new(hash[:min], hash[:max])
53
53
  end
54
54
 
55
- def initialize(min = -0.25, max = 0.25)
55
+ def initialize(min = -0.05, max = 0.05)
56
56
  @min = min
57
57
  @max = max
58
58
  end
@@ -151,6 +151,7 @@ module DNN
151
151
  end
152
152
 
153
153
  def train_on_batch(x, y, &batch_proc)
154
+ input_data_shape_check(x, y)
154
155
  x, y = batch_proc.call(x, y) if batch_proc
155
156
  forward(x, true)
156
157
  loss_value = loss(y)
@@ -160,6 +161,7 @@ module DNN
160
161
  end
161
162
 
162
163
  def accurate(x, y, batch_size = 1, &batch_proc)
164
+ input_data_shape_check(x, y)
163
165
  batch_size = batch_size >= x.shape[0] ? x.shape[0] : batch_size
164
166
  correct = 0
165
167
  (x.shape[0].to_f / batch_size).ceil.times do |i|
@@ -185,6 +187,7 @@ module DNN
185
187
  end
186
188
 
187
189
  def predict(x)
190
+ input_data_shape_check(x)
188
191
  forward(x, false)
189
192
  end
190
193
 
@@ -239,7 +242,11 @@ module DNN
239
242
  def get_prev_layer(layer)
240
243
  layer_index = @layers.index(layer)
241
244
  prev_layer = if layer_index == 0
242
- @super_model.layers[@super_model.layers.index(self) - 1]
245
+ if @super_model
246
+ @super_model.layers[@super_model.layers.index(self) - 1]
247
+ else
248
+ self
249
+ end
243
250
  else
244
251
  @layers[layer_index - 1]
245
252
  end
@@ -261,18 +268,31 @@ module DNN
261
268
  end
262
269
  end
263
270
 
271
+ def input_data_shape_check(x, y = nil)
272
+ unless @layers.first.shape == x.shape[1..-1]
273
+ raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.shape}.")
274
+ end
275
+ if y && @layers.last.shape != y.shape[1..-1]
276
+ raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.shape}.")
277
+ end
278
+ end
279
+
264
280
  def layers_shape_check
265
281
  @layers.each.with_index do |layer, i|
282
+ prev_shape = layer.prev_layer.shape
266
283
  if layer.is_a?(Layers::Dense)
267
- prev_shape = layer.prev_layer.shape
268
284
  if prev_shape.length != 1
269
285
  raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
270
286
  end
271
287
  elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
272
- prev_shape = layer.prev_layer.shape
273
288
  if prev_shape.length != 3
274
289
  raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
275
290
  end
291
+ elsif layer.is_a?(Layers::RNN)
292
+ if prev_shape.length != 2
293
+ layer_name = layer.class.name.match("\:\:(.+)$")[1]
294
+ raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
295
+ end
276
296
  end
277
297
  end
278
298
  end
data/lib/dnn/lib/mnist.rb CHANGED
@@ -1,6 +1,6 @@
1
1
  require "open-uri"
2
2
  require "zlib"
3
- require "dnn/core/error"
3
+ require_relative "../core/error"
4
4
 
5
5
  module DNN
6
6
  module MNIST
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.7.1"
2
+ VERSION = "0.7.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.7.1
4
+ version: 0.7.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-09-19 00:00:00.000000000 Z
11
+ date: 2018-10-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray