ruby-dnn 0.14.2 → 0.14.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 26d153a8ef243d2a069fd5b48a08e0216c28d8dd0307f570fce2ef146eef3cfe
4
- data.tar.gz: 9130ce3f31d2e7b5b3a3ecb48d9e4e10cddd307b2f900e13799206f8b16c3022
3
+ metadata.gz: 9e4633ce695dc370e62c6653d5ff07d81f9306625f7165750f8e455c4a28ca61
4
+ data.tar.gz: e4275720191b14592e2fb93581e9826e437ce9fcc75f6473cb963892d8f9f893
5
5
  SHA512:
6
- metadata.gz: f6167f81c1c7c76ae0539d17e9f9d252f03b603a15ee376cf91d0aa620d8d60aa5f434d3587f3331221679f661d22978207564d60a7ae59cf22551c482beb62c
7
- data.tar.gz: 0ab323a64dd7d68692c901257fa38981355c6533eb106a76ed1f4072c4d338bb4c5c72b0a4db2704c146ae49b827c75f45b79db924cdde92352f72f332bd7c9c
6
+ metadata.gz: d9831b1b3a73423742bce7f989a13c7ac3d63426a2bd8366143be1717a1e90fbdf6aa79a9e90508e193e00fbcd120942b7ccd0c8b6b13eb9cc4e237471712030
7
+ data.tar.gz: 60344f9eb3645b0560dd9d1f63a86dcc554f7c9b923a3b84a0bc301264120d50ca034d2887c1cc41989f223891ace4d64c570c9935e2b3f6f467a05f0fab910f
data/README.md CHANGED
@@ -55,7 +55,7 @@ class MLP < Model
55
55
  end
56
56
 
57
57
  def call(x)
58
- x = InputLayer.(x)
58
+ x = InputLayer.new(784).(x)
59
59
  x = @l1.(x)
60
60
  x = ReLU.(x)
61
61
  x = @l2.(x)
@@ -121,6 +121,6 @@ class DCGAN < Model
121
121
  label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
122
122
  dcgan_loss = train_on_batch(noise, label)
123
123
 
124
- { dis_loss: dis_loss.mean, dcgan_loss: dcgan_loss.mean }
124
+ { dis_loss: dis_loss, dcgan_loss: dcgan_loss }
125
125
  end
126
126
  end
@@ -19,11 +19,12 @@ dcgan = DCGAN.new(gen, dis)
19
19
 
20
20
  dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
21
21
  dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
22
+ dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
23
+ dcgan.predict1(Numo::SFloat.zeros(20))
22
24
 
23
25
  x_train, * = MNIST.load_train
24
26
  x_train = Numo::SFloat.cast(x_train)
25
27
  x_train = x_train / 127.5 - 1
26
28
 
27
- dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
28
- dcgan.predict1(Numo::SFloat.zeros(20))
29
- dcgan.train(x_train, x_train, epochs, batch_size: batch_size, last_round_down: true)
29
+ iter = DNN::Iterator.new(x_train, x_train, last_round_down: true)
30
+ dcgan.fit_by_iterator(iter, epochs, batch_size: batch_size)
@@ -32,7 +32,7 @@ class MLP < Model
32
32
  end
33
33
 
34
34
  def call(x)
35
- x = InputLayer.(x)
35
+ x = InputLayer.new(784).(x)
36
36
  x = @l1.(x)
37
37
  x = @bn1.(x)
38
38
  x = ReLU.(x)
@@ -385,6 +385,29 @@ module DNN
385
385
  end
386
386
  end
387
387
 
388
+ class GlobalAvgPool2D < Layer
389
+ def build(input_shape)
390
+ unless input_shape.length == 3
391
+ raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
392
+ end
393
+ super
394
+ @avg_pool2d = AvgPool2D.new(input_shape[0..1])
395
+ @avg_pool2d.build(input_shape)
396
+ @flatten = Flatten.new
397
+ @flatten.build([1, 1, input_shape[2]])
398
+ end
399
+
400
+ def forward(x)
401
+ y = @avg_pool2d.forward(x)
402
+ @flatten.forward(y)
403
+ end
404
+
405
+ def backward(dy)
406
+ dy = @flatten.backward(dy)
407
+ @avg_pool2d.backward(dy)
408
+ end
409
+ end
410
+
388
411
  class UnPool2D < Layer
389
412
  include Conv2DUtils
390
413
 
@@ -370,13 +370,19 @@ module DNN
370
370
  # @param [Array] stack All layers possessed by the model.
371
371
  def initialize(stack = [])
372
372
  super()
373
- @stack = stack.clone
373
+ @stack = []
374
+ stack.each do |layer|
375
+ add(layer)
376
+ end
374
377
  end
375
378
 
376
379
  # Add layer to the model.
377
380
  # @param [DNN::Layers::Layer] layer Layer to add to the model.
378
381
  # @return [DNN::Models::Model] Return self.
379
382
  def add(layer)
383
+ if layer.is_a?(MergeLayers::MergeLayer)
384
+ raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
385
+ end
380
386
  unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
381
387
  raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
382
388
  end
@@ -386,13 +392,23 @@ module DNN
386
392
 
387
393
  alias << add
388
394
 
395
+ # Insert layer to the model by index position.
396
+ # @param [DNN::Layers::Layer] layer Layer to add to the model.
397
+ # @return [DNN::Models::Model] Return self.
398
+ def insert(index, layer)
399
+ if layer.is_a?(MergeLayers::MergeLayer)
400
+ raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
401
+ end
402
+ unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
403
+ raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
404
+ end
405
+ @stack.insert(index, layer)
406
+ end
407
+
389
408
  # Remove layer to the model.
390
409
  # @param [DNN::Layers::Layer] layer Layer to remove to the model.
391
410
  # @return [Boolean] Return true if success for remove layer.
392
411
  def remove(layer)
393
- unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
394
- raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
395
- end
396
412
  @stack.delete(layer) ? true : false
397
413
  end
398
414
 
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.14.2"
2
+ VERSION = "0.14.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.2
4
+ version: 0.14.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-10-26 00:00:00.000000000 Z
11
+ date: 2019-11-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -170,7 +170,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
170
170
  - !ruby/object:Gem::Version
171
171
  version: '0'
172
172
  requirements: []
173
- rubygems_version: 3.0.1
173
+ rubygems_version: 3.0.3
174
174
  signing_key:
175
175
  specification_version: 4
176
176
  summary: ruby deep learning library.