ruby-dnn 0.14.2 → 0.14.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/examples/dcgan/dcgan.rb +1 -1
- data/examples/dcgan/train.rb +4 -3
- data/examples/mnist_define_by_run.rb +1 -1
- data/lib/dnn/core/cnn_layers.rb +23 -0
- data/lib/dnn/core/models.rb +20 -4
- data/lib/dnn/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9e4633ce695dc370e62c6653d5ff07d81f9306625f7165750f8e455c4a28ca61
|
4
|
+
data.tar.gz: e4275720191b14592e2fb93581e9826e437ce9fcc75f6473cb963892d8f9f893
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d9831b1b3a73423742bce7f989a13c7ac3d63426a2bd8366143be1717a1e90fbdf6aa79a9e90508e193e00fbcd120942b7ccd0c8b6b13eb9cc4e237471712030
|
7
|
+
data.tar.gz: 60344f9eb3645b0560dd9d1f63a86dcc554f7c9b923a3b84a0bc301264120d50ca034d2887c1cc41989f223891ace4d64c570c9935e2b3f6f467a05f0fab910f
|
data/README.md
CHANGED
data/examples/dcgan/dcgan.rb
CHANGED
@@ -121,6 +121,6 @@ class DCGAN < Model
|
|
121
121
|
label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
|
122
122
|
dcgan_loss = train_on_batch(noise, label)
|
123
123
|
|
124
|
-
{ dis_loss: dis_loss
|
124
|
+
{ dis_loss: dis_loss, dcgan_loss: dcgan_loss }
|
125
125
|
end
|
126
126
|
end
|
data/examples/dcgan/train.rb
CHANGED
@@ -19,11 +19,12 @@ dcgan = DCGAN.new(gen, dis)
|
|
19
19
|
|
20
20
|
dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
|
21
21
|
dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
|
22
|
+
dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
|
23
|
+
dcgan.predict1(Numo::SFloat.zeros(20))
|
22
24
|
|
23
25
|
x_train, * = MNIST.load_train
|
24
26
|
x_train = Numo::SFloat.cast(x_train)
|
25
27
|
x_train = x_train / 127.5 - 1
|
26
28
|
|
27
|
-
|
28
|
-
dcgan.
|
29
|
-
dcgan.train(x_train, x_train, epochs, batch_size: batch_size, last_round_down: true)
|
29
|
+
iter = DNN::Iterator.new(x_train, x_train, last_round_down: true)
|
30
|
+
dcgan.fit_by_iterator(iter, epochs, batch_size: batch_size)
|
data/lib/dnn/core/cnn_layers.rb
CHANGED
@@ -385,6 +385,29 @@ module DNN
|
|
385
385
|
end
|
386
386
|
end
|
387
387
|
|
388
|
+
class GlobalAvgPool2D < Layer
|
389
|
+
def build(input_shape)
|
390
|
+
unless input_shape.length == 3
|
391
|
+
raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
|
392
|
+
end
|
393
|
+
super
|
394
|
+
@avg_pool2d = AvgPool2D.new(input_shape[0..1])
|
395
|
+
@avg_pool2d.build(input_shape)
|
396
|
+
@flatten = Flatten.new
|
397
|
+
@flatten.build([1, 1, input_shape[2]])
|
398
|
+
end
|
399
|
+
|
400
|
+
def forward(x)
|
401
|
+
y = @avg_pool2d.forward(x)
|
402
|
+
@flatten.forward(y)
|
403
|
+
end
|
404
|
+
|
405
|
+
def backward(dy)
|
406
|
+
dy = @flatten.backward(dy)
|
407
|
+
@avg_pool2d.backward(dy)
|
408
|
+
end
|
409
|
+
end
|
410
|
+
|
388
411
|
class UnPool2D < Layer
|
389
412
|
include Conv2DUtils
|
390
413
|
|
data/lib/dnn/core/models.rb
CHANGED
@@ -370,13 +370,19 @@ module DNN
|
|
370
370
|
# @param [Array] stack All layers possessed by the model.
|
371
371
|
def initialize(stack = [])
|
372
372
|
super()
|
373
|
-
@stack =
|
373
|
+
@stack = []
|
374
|
+
stack.each do |layer|
|
375
|
+
add(layer)
|
376
|
+
end
|
374
377
|
end
|
375
378
|
|
376
379
|
# Add layer to the model.
|
377
380
|
# @param [DNN::Layers::Layer] layer Layer to add to the model.
|
378
381
|
# @return [DNN::Models::Model] Return self.
|
379
382
|
def add(layer)
|
383
|
+
if layer.is_a?(MergeLayers::MergeLayer)
|
384
|
+
raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
|
385
|
+
end
|
380
386
|
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
381
387
|
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
382
388
|
end
|
@@ -386,13 +392,23 @@ module DNN
|
|
386
392
|
|
387
393
|
alias << add
|
388
394
|
|
395
|
+
# Insert layer to the model by index position.
|
396
|
+
# @param [DNN::Layers::Layer] layer Layer to add to the model.
|
397
|
+
# @return [DNN::Models::Model] Return self.
|
398
|
+
def insert(index, layer)
|
399
|
+
if layer.is_a?(MergeLayers::MergeLayer)
|
400
|
+
raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
|
401
|
+
end
|
402
|
+
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
403
|
+
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
404
|
+
end
|
405
|
+
@stack.insert(index, layer)
|
406
|
+
end
|
407
|
+
|
389
408
|
# Remove layer to the model.
|
390
409
|
# @param [DNN::Layers::Layer] layer Layer to remove to the model.
|
391
410
|
# @return [Boolean] Return true if success for remove layer.
|
392
411
|
def remove(layer)
|
393
|
-
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
394
|
-
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
395
|
-
end
|
396
412
|
@stack.delete(layer) ? true : false
|
397
413
|
end
|
398
414
|
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.
|
4
|
+
version: 0.14.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-11-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -170,7 +170,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
170
170
|
- !ruby/object:Gem::Version
|
171
171
|
version: '0'
|
172
172
|
requirements: []
|
173
|
-
rubygems_version: 3.0.
|
173
|
+
rubygems_version: 3.0.3
|
174
174
|
signing_key:
|
175
175
|
specification_version: 4
|
176
176
|
summary: ruby deep learning library.
|