ruby-dnn 0.7.1 → 0.7.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e35a9f970e0b936026f31c819d3a6763f6362706e14c74a98771895e809835cf
4
- data.tar.gz: ea8918fc438cfcda9dba644871c56d12659137a949ab74351b6bea6ba69e2318
3
+ metadata.gz: cf955bc9993a8c9495b144659cedc0cc97c1bde3bc1521be2453847d9d3afd91
4
+ data.tar.gz: b64395b3976619575ded55fd7ab508d09a8f0cde0dd0ab8a7d611dd7ae76f169
5
5
  SHA512:
6
- metadata.gz: 7fcf64d42d502e9cfef831ba8a4e8c3afc2bade2632c42fb20a180baec494fecf026e087bfbf758f32f331243bac217ee1ab05e178ae1c12a7172c0c0ef96681
7
- data.tar.gz: a5064e28bec910a3f344af5777ca0a2da2495fd82a8f6a7cf1f33c8422d3e09459a636d90862a7db875cf2cb1b498926b7b1c5959b3e24c0cc12394a5de0f3dc
6
+ metadata.gz: 8dca44e876a1403b8b02e94936ce46f98a7dfa14a47874976c88fad12544f447481959cec6d563c6af6cfe9fa643f37a18c016b39e1da955f915f18027580f11
7
+ data.tar.gz: ad647ff01b6f9d9afe9a7dde002e68dce80dc1de57570d160e684f2a54fc635c9a0f6661fc563adbbe8196bd517f7c86f546302b185bac557209773f8d832da0
@@ -52,7 +52,7 @@ module DNN
52
52
  self.new(hash[:min], hash[:max])
53
53
  end
54
54
 
55
- def initialize(min = -0.25, max = 0.25)
55
+ def initialize(min = -0.05, max = 0.05)
56
56
  @min = min
57
57
  @max = max
58
58
  end
@@ -151,6 +151,7 @@ module DNN
151
151
  end
152
152
 
153
153
  def train_on_batch(x, y, &batch_proc)
154
+ input_data_shape_check(x, y)
154
155
  x, y = batch_proc.call(x, y) if batch_proc
155
156
  forward(x, true)
156
157
  loss_value = loss(y)
@@ -160,6 +161,7 @@ module DNN
160
161
  end
161
162
 
162
163
  def accurate(x, y, batch_size = 1, &batch_proc)
164
+ input_data_shape_check(x, y)
163
165
  batch_size = batch_size >= x.shape[0] ? x.shape[0] : batch_size
164
166
  correct = 0
165
167
  (x.shape[0].to_f / batch_size).ceil.times do |i|
@@ -185,6 +187,7 @@ module DNN
185
187
  end
186
188
 
187
189
  def predict(x)
190
+ input_data_shape_check(x)
188
191
  forward(x, false)
189
192
  end
190
193
 
@@ -239,7 +242,11 @@ module DNN
239
242
  def get_prev_layer(layer)
240
243
  layer_index = @layers.index(layer)
241
244
  prev_layer = if layer_index == 0
242
- @super_model.layers[@super_model.layers.index(self) - 1]
245
+ if @super_model
246
+ @super_model.layers[@super_model.layers.index(self) - 1]
247
+ else
248
+ self
249
+ end
243
250
  else
244
251
  @layers[layer_index - 1]
245
252
  end
@@ -261,18 +268,31 @@ module DNN
261
268
  end
262
269
  end
263
270
 
271
+ def input_data_shape_check(x, y = nil)
272
+ unless @layers.first.shape == x.shape[1..-1]
273
+ raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.shape}.")
274
+ end
275
+ if y && @layers.last.shape != y.shape[1..-1]
276
+ raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.shape}.")
277
+ end
278
+ end
279
+
264
280
  def layers_shape_check
265
281
  @layers.each.with_index do |layer, i|
282
+ prev_shape = layer.prev_layer.shape
266
283
  if layer.is_a?(Layers::Dense)
267
- prev_shape = layer.prev_layer.shape
268
284
  if prev_shape.length != 1
269
285
  raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
270
286
  end
271
287
  elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
272
- prev_shape = layer.prev_layer.shape
273
288
  if prev_shape.length != 3
274
289
  raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
275
290
  end
291
+ elsif layer.is_a?(Layers::RNN)
292
+ if prev_shape.length != 2
293
+ layer_name = layer.class.name.match("\:\:(.+)$")[1]
294
+ raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
295
+ end
276
296
  end
277
297
  end
278
298
  end
data/lib/dnn/lib/mnist.rb CHANGED
@@ -1,6 +1,6 @@
1
1
  require "open-uri"
2
2
  require "zlib"
3
- require "dnn/core/error"
3
+ require_relative "../core/error"
4
4
 
5
5
  module DNN
6
6
  module MNIST
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.7.1"
2
+ VERSION = "0.7.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.7.1
4
+ version: 0.7.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-09-19 00:00:00.000000000 Z
11
+ date: 2018-10-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray