ruby-dnn 0.7.1 → 0.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/dnn/core/initializers.rb +1 -1
- data/lib/dnn/core/model.rb +23 -3
- data/lib/dnn/lib/mnist.rb +1 -1
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: cf955bc9993a8c9495b144659cedc0cc97c1bde3bc1521be2453847d9d3afd91
|
4
|
+
data.tar.gz: b64395b3976619575ded55fd7ab508d09a8f0cde0dd0ab8a7d611dd7ae76f169
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 8dca44e876a1403b8b02e94936ce46f98a7dfa14a47874976c88fad12544f447481959cec6d563c6af6cfe9fa643f37a18c016b39e1da955f915f18027580f11
|
7
|
+
data.tar.gz: ad647ff01b6f9d9afe9a7dde002e68dce80dc1de57570d160e684f2a54fc635c9a0f6661fc563adbbe8196bd517f7c86f546302b185bac557209773f8d832da0
|
data/lib/dnn/core/model.rb
CHANGED
@@ -151,6 +151,7 @@ module DNN
|
|
151
151
|
end
|
152
152
|
|
153
153
|
def train_on_batch(x, y, &batch_proc)
|
154
|
+
input_data_shape_check(x, y)
|
154
155
|
x, y = batch_proc.call(x, y) if batch_proc
|
155
156
|
forward(x, true)
|
156
157
|
loss_value = loss(y)
|
@@ -160,6 +161,7 @@ module DNN
|
|
160
161
|
end
|
161
162
|
|
162
163
|
def accurate(x, y, batch_size = 1, &batch_proc)
|
164
|
+
input_data_shape_check(x, y)
|
163
165
|
batch_size = batch_size >= x.shape[0] ? x.shape[0] : batch_size
|
164
166
|
correct = 0
|
165
167
|
(x.shape[0].to_f / batch_size).ceil.times do |i|
|
@@ -185,6 +187,7 @@ module DNN
|
|
185
187
|
end
|
186
188
|
|
187
189
|
def predict(x)
|
190
|
+
input_data_shape_check(x)
|
188
191
|
forward(x, false)
|
189
192
|
end
|
190
193
|
|
@@ -239,7 +242,11 @@ module DNN
|
|
239
242
|
def get_prev_layer(layer)
|
240
243
|
layer_index = @layers.index(layer)
|
241
244
|
prev_layer = if layer_index == 0
|
242
|
-
@super_model
|
245
|
+
if @super_model
|
246
|
+
@super_model.layers[@super_model.layers.index(self) - 1]
|
247
|
+
else
|
248
|
+
self
|
249
|
+
end
|
243
250
|
else
|
244
251
|
@layers[layer_index - 1]
|
245
252
|
end
|
@@ -261,18 +268,31 @@ module DNN
|
|
261
268
|
end
|
262
269
|
end
|
263
270
|
|
271
|
+
def input_data_shape_check(x, y = nil)
|
272
|
+
unless @layers.first.shape == x.shape[1..-1]
|
273
|
+
raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.shape}.")
|
274
|
+
end
|
275
|
+
if y && @layers.last.shape != y.shape[1..-1]
|
276
|
+
raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.shape}.")
|
277
|
+
end
|
278
|
+
end
|
279
|
+
|
264
280
|
def layers_shape_check
|
265
281
|
@layers.each.with_index do |layer, i|
|
282
|
+
prev_shape = layer.prev_layer.shape
|
266
283
|
if layer.is_a?(Layers::Dense)
|
267
|
-
prev_shape = layer.prev_layer.shape
|
268
284
|
if prev_shape.length != 1
|
269
285
|
raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
|
270
286
|
end
|
271
287
|
elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
|
272
|
-
prev_shape = layer.prev_layer.shape
|
273
288
|
if prev_shape.length != 3
|
274
289
|
raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
|
275
290
|
end
|
291
|
+
elsif layer.is_a?(Layers::RNN)
|
292
|
+
if prev_shape.length != 2
|
293
|
+
layer_name = layer.class.name.match("\:\:(.+)$")[1]
|
294
|
+
raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
|
295
|
+
end
|
276
296
|
end
|
277
297
|
end
|
278
298
|
end
|
data/lib/dnn/lib/mnist.rb
CHANGED
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.7.
|
4
|
+
version: 0.7.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-10-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|