ruby-dnn 0.14.2 → 0.14.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/examples/dcgan/dcgan.rb +1 -1
- data/examples/dcgan/train.rb +4 -3
- data/examples/mnist_define_by_run.rb +1 -1
- data/lib/dnn/core/cnn_layers.rb +23 -0
- data/lib/dnn/core/models.rb +20 -4
- data/lib/dnn/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9e4633ce695dc370e62c6653d5ff07d81f9306625f7165750f8e455c4a28ca61
|
4
|
+
data.tar.gz: e4275720191b14592e2fb93581e9826e437ce9fcc75f6473cb963892d8f9f893
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d9831b1b3a73423742bce7f989a13c7ac3d63426a2bd8366143be1717a1e90fbdf6aa79a9e90508e193e00fbcd120942b7ccd0c8b6b13eb9cc4e237471712030
|
7
|
+
data.tar.gz: 60344f9eb3645b0560dd9d1f63a86dcc554f7c9b923a3b84a0bc301264120d50ca034d2887c1cc41989f223891ace4d64c570c9935e2b3f6f467a05f0fab910f
|
data/README.md
CHANGED
data/examples/dcgan/dcgan.rb
CHANGED
@@ -121,6 +121,6 @@ class DCGAN < Model
|
|
121
121
|
label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
|
122
122
|
dcgan_loss = train_on_batch(noise, label)
|
123
123
|
|
124
|
-
{ dis_loss: dis_loss
|
124
|
+
{ dis_loss: dis_loss, dcgan_loss: dcgan_loss }
|
125
125
|
end
|
126
126
|
end
|
data/examples/dcgan/train.rb
CHANGED
@@ -19,11 +19,12 @@ dcgan = DCGAN.new(gen, dis)
|
|
19
19
|
|
20
20
|
dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
|
21
21
|
dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
|
22
|
+
dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
|
23
|
+
dcgan.predict1(Numo::SFloat.zeros(20))
|
22
24
|
|
23
25
|
x_train, * = MNIST.load_train
|
24
26
|
x_train = Numo::SFloat.cast(x_train)
|
25
27
|
x_train = x_train / 127.5 - 1
|
26
28
|
|
27
|
-
|
28
|
-
dcgan.
|
29
|
-
dcgan.train(x_train, x_train, epochs, batch_size: batch_size, last_round_down: true)
|
29
|
+
iter = DNN::Iterator.new(x_train, x_train, last_round_down: true)
|
30
|
+
dcgan.fit_by_iterator(iter, epochs, batch_size: batch_size)
|
data/lib/dnn/core/cnn_layers.rb
CHANGED
@@ -385,6 +385,29 @@ module DNN
|
|
385
385
|
end
|
386
386
|
end
|
387
387
|
|
388
|
+
class GlobalAvgPool2D < Layer
|
389
|
+
def build(input_shape)
|
390
|
+
unless input_shape.length == 3
|
391
|
+
raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
|
392
|
+
end
|
393
|
+
super
|
394
|
+
@avg_pool2d = AvgPool2D.new(input_shape[0..1])
|
395
|
+
@avg_pool2d.build(input_shape)
|
396
|
+
@flatten = Flatten.new
|
397
|
+
@flatten.build([1, 1, input_shape[2]])
|
398
|
+
end
|
399
|
+
|
400
|
+
def forward(x)
|
401
|
+
y = @avg_pool2d.forward(x)
|
402
|
+
@flatten.forward(y)
|
403
|
+
end
|
404
|
+
|
405
|
+
def backward(dy)
|
406
|
+
dy = @flatten.backward(dy)
|
407
|
+
@avg_pool2d.backward(dy)
|
408
|
+
end
|
409
|
+
end
|
410
|
+
|
388
411
|
class UnPool2D < Layer
|
389
412
|
include Conv2DUtils
|
390
413
|
|
data/lib/dnn/core/models.rb
CHANGED
@@ -370,13 +370,19 @@ module DNN
|
|
370
370
|
# @param [Array] stack All layers possessed by the model.
|
371
371
|
def initialize(stack = [])
|
372
372
|
super()
|
373
|
-
@stack =
|
373
|
+
@stack = []
|
374
|
+
stack.each do |layer|
|
375
|
+
add(layer)
|
376
|
+
end
|
374
377
|
end
|
375
378
|
|
376
379
|
# Add layer to the model.
|
377
380
|
# @param [DNN::Layers::Layer] layer Layer to add to the model.
|
378
381
|
# @return [DNN::Models::Model] Return self.
|
379
382
|
def add(layer)
|
383
|
+
if layer.is_a?(MergeLayers::MergeLayer)
|
384
|
+
raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
|
385
|
+
end
|
380
386
|
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
381
387
|
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
382
388
|
end
|
@@ -386,13 +392,23 @@ module DNN
|
|
386
392
|
|
387
393
|
alias << add
|
388
394
|
|
395
|
+
# Insert layer to the model by index position.
|
396
|
+
# @param [DNN::Layers::Layer] layer Layer to add to the model.
|
397
|
+
# @return [DNN::Models::Model] Return self.
|
398
|
+
def insert(index, layer)
|
399
|
+
if layer.is_a?(MergeLayers::MergeLayer)
|
400
|
+
raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
|
401
|
+
end
|
402
|
+
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
403
|
+
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
404
|
+
end
|
405
|
+
@stack.insert(index, layer)
|
406
|
+
end
|
407
|
+
|
389
408
|
# Remove layer to the model.
|
390
409
|
# @param [DNN::Layers::Layer] layer Layer to remove to the model.
|
391
410
|
# @return [Boolean] Return true if success for remove layer.
|
392
411
|
def remove(layer)
|
393
|
-
unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
|
394
|
-
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
|
395
|
-
end
|
396
412
|
@stack.delete(layer) ? true : false
|
397
413
|
end
|
398
414
|
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.
|
4
|
+
version: 0.14.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-11-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -170,7 +170,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
170
170
|
- !ruby/object:Gem::Version
|
171
171
|
version: '0'
|
172
172
|
requirements: []
|
173
|
-
rubygems_version: 3.0.
|
173
|
+
rubygems_version: 3.0.3
|
174
174
|
signing_key:
|
175
175
|
specification_version: 4
|
176
176
|
summary: ruby deep learning library.
|