ruby-dnn 0.7.1 → 0.7.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/dnn/core/initializers.rb +1 -1
- data/lib/dnn/core/model.rb +23 -3
- data/lib/dnn/lib/mnist.rb +1 -1
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: cf955bc9993a8c9495b144659cedc0cc97c1bde3bc1521be2453847d9d3afd91
|
4
|
+
data.tar.gz: b64395b3976619575ded55fd7ab508d09a8f0cde0dd0ab8a7d611dd7ae76f169
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 8dca44e876a1403b8b02e94936ce46f98a7dfa14a47874976c88fad12544f447481959cec6d563c6af6cfe9fa643f37a18c016b39e1da955f915f18027580f11
|
7
|
+
data.tar.gz: ad647ff01b6f9d9afe9a7dde002e68dce80dc1de57570d160e684f2a54fc635c9a0f6661fc563adbbe8196bd517f7c86f546302b185bac557209773f8d832da0
|
data/lib/dnn/core/model.rb
CHANGED
@@ -151,6 +151,7 @@ module DNN
|
|
151
151
|
end
|
152
152
|
|
153
153
|
def train_on_batch(x, y, &batch_proc)
|
154
|
+
input_data_shape_check(x, y)
|
154
155
|
x, y = batch_proc.call(x, y) if batch_proc
|
155
156
|
forward(x, true)
|
156
157
|
loss_value = loss(y)
|
@@ -160,6 +161,7 @@ module DNN
|
|
160
161
|
end
|
161
162
|
|
162
163
|
def accurate(x, y, batch_size = 1, &batch_proc)
|
164
|
+
input_data_shape_check(x, y)
|
163
165
|
batch_size = batch_size >= x.shape[0] ? x.shape[0] : batch_size
|
164
166
|
correct = 0
|
165
167
|
(x.shape[0].to_f / batch_size).ceil.times do |i|
|
@@ -185,6 +187,7 @@ module DNN
|
|
185
187
|
end
|
186
188
|
|
187
189
|
def predict(x)
|
190
|
+
input_data_shape_check(x)
|
188
191
|
forward(x, false)
|
189
192
|
end
|
190
193
|
|
@@ -239,7 +242,11 @@ module DNN
|
|
239
242
|
def get_prev_layer(layer)
|
240
243
|
layer_index = @layers.index(layer)
|
241
244
|
prev_layer = if layer_index == 0
|
242
|
-
@super_model
|
245
|
+
if @super_model
|
246
|
+
@super_model.layers[@super_model.layers.index(self) - 1]
|
247
|
+
else
|
248
|
+
self
|
249
|
+
end
|
243
250
|
else
|
244
251
|
@layers[layer_index - 1]
|
245
252
|
end
|
@@ -261,18 +268,31 @@ module DNN
|
|
261
268
|
end
|
262
269
|
end
|
263
270
|
|
271
|
+
def input_data_shape_check(x, y = nil)
|
272
|
+
unless @layers.first.shape == x.shape[1..-1]
|
273
|
+
raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.shape}.")
|
274
|
+
end
|
275
|
+
if y && @layers.last.shape != y.shape[1..-1]
|
276
|
+
raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.shape}.")
|
277
|
+
end
|
278
|
+
end
|
279
|
+
|
264
280
|
def layers_shape_check
|
265
281
|
@layers.each.with_index do |layer, i|
|
282
|
+
prev_shape = layer.prev_layer.shape
|
266
283
|
if layer.is_a?(Layers::Dense)
|
267
|
-
prev_shape = layer.prev_layer.shape
|
268
284
|
if prev_shape.length != 1
|
269
285
|
raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
|
270
286
|
end
|
271
287
|
elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
|
272
|
-
prev_shape = layer.prev_layer.shape
|
273
288
|
if prev_shape.length != 3
|
274
289
|
raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
|
275
290
|
end
|
291
|
+
elsif layer.is_a?(Layers::RNN)
|
292
|
+
if prev_shape.length != 2
|
293
|
+
layer_name = layer.class.name.match("\:\:(.+)$")[1]
|
294
|
+
raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
|
295
|
+
end
|
276
296
|
end
|
277
297
|
end
|
278
298
|
end
|
data/lib/dnn/lib/mnist.rb
CHANGED
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.7.
|
4
|
+
version: 0.7.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-10-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|