ruby-dnn 0.14.2 → 0.14.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 26d153a8ef243d2a069fd5b48a08e0216c28d8dd0307f570fce2ef146eef3cfe
4
- data.tar.gz: 9130ce3f31d2e7b5b3a3ecb48d9e4e10cddd307b2f900e13799206f8b16c3022
3
+ metadata.gz: 9e4633ce695dc370e62c6653d5ff07d81f9306625f7165750f8e455c4a28ca61
4
+ data.tar.gz: e4275720191b14592e2fb93581e9826e437ce9fcc75f6473cb963892d8f9f893
5
5
  SHA512:
6
- metadata.gz: f6167f81c1c7c76ae0539d17e9f9d252f03b603a15ee376cf91d0aa620d8d60aa5f434d3587f3331221679f661d22978207564d60a7ae59cf22551c482beb62c
7
- data.tar.gz: 0ab323a64dd7d68692c901257fa38981355c6533eb106a76ed1f4072c4d338bb4c5c72b0a4db2704c146ae49b827c75f45b79db924cdde92352f72f332bd7c9c
6
+ metadata.gz: d9831b1b3a73423742bce7f989a13c7ac3d63426a2bd8366143be1717a1e90fbdf6aa79a9e90508e193e00fbcd120942b7ccd0c8b6b13eb9cc4e237471712030
7
+ data.tar.gz: 60344f9eb3645b0560dd9d1f63a86dcc554f7c9b923a3b84a0bc301264120d50ca034d2887c1cc41989f223891ace4d64c570c9935e2b3f6f467a05f0fab910f
data/README.md CHANGED
@@ -55,7 +55,7 @@ class MLP < Model
55
55
  end
56
56
 
57
57
  def call(x)
58
- x = InputLayer.(x)
58
+ x = InputLayer.new(784).(x)
59
59
  x = @l1.(x)
60
60
  x = ReLU.(x)
61
61
  x = @l2.(x)
@@ -121,6 +121,6 @@ class DCGAN < Model
121
121
  label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
122
122
  dcgan_loss = train_on_batch(noise, label)
123
123
 
124
- { dis_loss: dis_loss.mean, dcgan_loss: dcgan_loss.mean }
124
+ { dis_loss: dis_loss, dcgan_loss: dcgan_loss }
125
125
  end
126
126
  end
@@ -19,11 +19,12 @@ dcgan = DCGAN.new(gen, dis)
19
19
 
20
20
  dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
21
21
  dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
22
+ dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
23
+ dcgan.predict1(Numo::SFloat.zeros(20))
22
24
 
23
25
  x_train, * = MNIST.load_train
24
26
  x_train = Numo::SFloat.cast(x_train)
25
27
  x_train = x_train / 127.5 - 1
26
28
 
27
- dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))
28
- dcgan.predict1(Numo::SFloat.zeros(20))
29
- dcgan.train(x_train, x_train, epochs, batch_size: batch_size, last_round_down: true)
29
+ iter = DNN::Iterator.new(x_train, x_train, last_round_down: true)
30
+ dcgan.fit_by_iterator(iter, epochs, batch_size: batch_size)
@@ -32,7 +32,7 @@ class MLP < Model
32
32
  end
33
33
 
34
34
  def call(x)
35
- x = InputLayer.(x)
35
+ x = InputLayer.new(784).(x)
36
36
  x = @l1.(x)
37
37
  x = @bn1.(x)
38
38
  x = ReLU.(x)
@@ -385,6 +385,29 @@ module DNN
385
385
  end
386
386
  end
387
387
 
388
# Global average pooling layer.
# Averages each sample over its first two axes, using a pooling window as
# large as those axes, and flattens the pooled [1, 1, dim2] map to [dim2].
# This presumably means (rows, cols, channels) -> (channels); confirm
# against the layout AvgPool2D expects.
class GlobalAvgPool2D < Layer
  # Validate the input shape and build the internal pooling/flatten layers.
  # @param [Array] input_shape Per-sample input shape; must have 3 dimensions.
  # @raise [DNN_ShapeError] If input_shape is not 3 dimensional.
  def build(input_shape)
    unless input_shape.length == 3
      raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
    end
    super
    # A window equal to the first two axes reduces each of them to size 1.
    @avg_pool2d = AvgPool2D.new(input_shape[0..1])
    @avg_pool2d.build(input_shape)
    @flatten = Flatten.new
    @flatten.build([1, 1, input_shape[2]])
    @channels = input_shape[2]
  end

  # @param x Batched input tensor.
  # @return Pooled-and-flattened output tensor.
  def forward(x)
    y = @avg_pool2d.forward(x)
    @flatten.forward(y)
  end

  # Backpropagate through the internal layers in reverse order.
  # @param dy Gradient of the flattened output.
  # @return Gradient with respect to the layer input.
  def backward(dy)
    dy = @flatten.backward(dy)
    @avg_pool2d.backward(dy)
  end

  # Per-sample output shape: one averaged value per element of the third
  # input axis, matching what #forward actually returns.
  # Overridden so the layer does not fall back to a default shape.
  # NOTE(review): assumes the Layer base class would otherwise report
  # the input shape here; confirm.
  # @return [Array, nil] [dim2] once built, nil before #build.
  def output_shape
    [@channels] if @channels
  end
end
410
+
388
411
  class UnPool2D < Layer
389
412
  include Conv2DUtils
390
413
 
@@ -370,13 +370,19 @@ module DNN
370
370
  # @param [Array] stack All layers possessed by the model.
371
371
  def initialize(stack = [])
372
372
  super()
373
- @stack = stack.clone
373
+ @stack = []
374
+ stack.each do |layer|
375
+ add(layer)
376
+ end
374
377
  end
375
378
 
376
379
  # Add layer to the model.
377
380
  # @param [DNN::Layers::Layer] layer Layer to add to the model.
378
381
  # @return [DNN::Models::Model] Return self.
379
382
  def add(layer)
383
+ if layer.is_a?(MergeLayers::MergeLayer)
384
+ raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
385
+ end
380
386
  unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
381
387
  raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
382
388
  end
@@ -386,13 +392,23 @@ module DNN
386
392
 
387
393
  alias << add
388
394
 
395
# Insert a layer into the model at the given index position.
# @param [Integer] index Position in the layer stack. Negative values count
#   from the end, as with Array#insert.
# @param [DNN::Layers::Layer] layer Layer (or nested model) to insert.
# @return [DNN::Models::Model] Return self.
# @raise [TypeError] If layer is a MergeLayer or not a Layer/Model instance.
# @raise [IndexError] If index lies beyond the end of the layer stack.
def insert(index, layer)
  if layer.is_a?(MergeLayers::MergeLayer)
    raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
  end
  unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
    raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
  end
  # Array#insert past the end pads the gap with nil, which is not a layer.
  if index.is_a?(Integer) && index > @stack.length
    raise IndexError, "index #{index} is out of range for #{@stack.length} layers."
  end
  @stack.insert(index, layer)
  # Return the model, not the internal stack array, as documented.
  self
end
407
+
389
408
  # Remove layer to the model.
390
409
  # @param [DNN::Layers::Layer] layer Layer to remove to the model.
391
410
  # @return [Boolean] Return true if success for remove layer.
392
411
  def remove(layer)
393
- unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
394
- raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
395
- end
396
412
  @stack.delete(layer) ? true : false
397
413
  end
398
414
 
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.14.2"
2
+ VERSION = "0.14.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.2
4
+ version: 0.14.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-10-26 00:00:00.000000000 Z
11
+ date: 2019-11-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -170,7 +170,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
170
170
  - !ruby/object:Gem::Version
171
171
  version: '0'
172
172
  requirements: []
173
- rubygems_version: 3.0.1
173
+ rubygems_version: 3.0.3
174
174
  signing_key:
175
175
  specification_version: 4
176
176
  summary: ruby deep learning library.