ruby-dnn 0.16.1 → 0.16.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 2e04eef16303e2e223bfdb1ac17e2a35fc20d0eafff022f71e412d8b41dc49c9
-   data.tar.gz: d0cd83a209069eba6ea2c48fc6c3480e5edf5b99932b31cb9d7a73adb0eb573d
+   metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
+   data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
  SHA512:
-   metadata.gz: 2b7126723c81495c603a1b98fc64fdd111d4853db78e4f0c840f75e925d60d97a3dd6ba22a9a9e2a8f643562a1284bfe502e38e0e9e4870c7699f16ed151df40
-   data.tar.gz: c1c8900e92ec77a03f46853a9ff14d4b09c8442a597f6ae87d1ee049764a3316e1b7ff92cdb4c13c4b2787d2759ff5c8f6c28c11668503d1d0062a2f671a04b7
+   metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
+   data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
examples/vae.rb ADDED
@@ -0,0 +1,118 @@
+ require "dnn"
+ require "dnn/datasets/mnist"
+ require "dnn/image"
+ require "numo/linalg/autoloader"
+ 
+ include DNN::Models
+ include DNN::Layers
+ include DNN::Optimizers
+ include DNN::Losses
+ 
+ x_train, y_train = DNN::MNIST.load_train
+ x_test, y_test = DNN::MNIST.load_test
+ 
+ x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
+ x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
+ 
+ x_train /= 255
+ x_test /= 255
+ 
+ y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
+ y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
+ 
+ $z_dim = 2
+ $z_mean = nil
+ $z_sigma = nil
+ 
+ class Sampling < MergeLayer
+   def forward(z_mean, z_sigma)
+     epsilon = DNN::Tensor.new(Numo::SFloat.new($z_dim).rand_norm(0, 1))
+     Tanh.(z_mean + z_sigma * epsilon)
+   end
+ end
+ 
+ class Encoder < Model
+   def initialize
+     super
+     @l1 = Dense.new(196)
+     @l2 = Dense.new(49)
+     @l3_1 = Dense.new($z_dim)
+     @l3_2 = Dense.new($z_dim)
+     @bn1 = BatchNormalization.new
+     @bn2 = BatchNormalization.new
+   end
+ 
+   def forward(x)
+     x = InputLayer.new(784).(x)
+     x = @l1.(x)
+     x = @bn1.(x)
+     x = ReLU.(x)
+     x = @l2.(x)
+     x = @bn2.(x)
+     x = ReLU.(x)
+     z_mean = @l3_1.(x)
+     z_sigma = @l3_2.(x)
+     [z_mean, z_sigma]
+   end
+ end
+ 
+ class Decoder < Model
+   def initialize
+     super
+     @l3 = Dense.new(196)
+     @l4 = Dense.new(784)
+     @bn1 = BatchNormalization.new
+   end
+ 
+   def forward(z)
+     x = @l3.(z)
+     x = @bn1.(x)
+     x = ReLU.(x)
+     x = @l4.(x)
+     x
+   end
+ end
+ 
+ class VAE < Model
+   attr_accessor :enc
+   attr_accessor :dec
+ 
+   def initialize(enc = nil, dec = nil)
+     super()
+     @enc = enc || Encoder.new
+     @dec = dec || Decoder.new
+   end
+ 
+   def forward(x)
+     z_mean, z_sigma = @enc.(x)
+     $z_mean, $z_sigma = z_mean, z_sigma
+     z = Sampling.(z_mean, z_sigma)
+     x = @dec.(z)
+     x
+   end
+ end
+ 
+ class VAELoss < Loss
+   def forward(y, t)
+     kl = -0.5 * Mean.(Sum.(1 + Log.($z_sigma**2) - $z_mean**2 - $z_sigma**2, axis: 1), axis: 0)
+     SigmoidCrossEntropy.(y, t) + kl
+   end
+ end
+ 
+ model = VAE.new
+ dec = model.dec
+ model.setup(Adam.new, VAELoss.new)
+ 
+ model.train(x_train, x_train, 10, batch_size: 100)
+ 
+ images = []
+ 10.times do |i|
+   10.times do |j|
+     z1 = (i / 4.5) - 1
+     z2 = (j / 4.5) - 1
+     z = Numo::SFloat[z1, z2]
+     out = DNN::Utils.sigmoid(dec.predict1(z))
+     img = Numo::UInt8.cast(out * 255).reshape(28, 28, 1)
+     DNN::Image.write("img/img_#{i}_#{j}.png", img)
+   end
+ end
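
The VAELoss above pairs a reconstruction term (SigmoidCrossEntropy over the 784 pixels) with a KL regularizer. For reference, and not quoted from the gem's documentation: for a diagonal Gaussian posterior N(mu, sigma^2) measured against a unit Gaussian prior, the per-sample KL term is

    KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * sum_d ( 1 + log sigma_d^2 - mu_d^2 - sigma_d^2 )

which is exactly what the kl line computes, with Sum taken over the latent axis and Mean taken over the batch.
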
@@ -2,6 +2,8 @@ module DNN
  module Layers
 
    class Add < MergeLayer
+     include MergeLayerNode
+
      def forward_node(x1, x2)
        x1 + x2
      end
@@ -12,6 +14,8 @@ module DNN
    end
 
    class Sub < MergeLayer
+     include MergeLayerNode
+
      def forward_node(x1, x2)
        x1 - x2
      end
@@ -22,6 +26,8 @@ module DNN
    end
 
    class Mul < MergeLayer
+     include MergeLayerNode
+
      def forward_node(x1, x2)
        @x1, @x2 = x1, x2
        x1 * x2
@@ -33,6 +39,8 @@ module DNN
    end
 
    class Div < MergeLayer
+     include MergeLayerNode
+
      def forward_node(x1, x2)
        @x1, @x2 = x1, x2
        x1 / x2
@@ -46,6 +54,8 @@ module DNN
    end
 
    class Dot < MergeLayer
+     include MergeLayerNode
+
      def forward_node(x1, x2)
        @x1, @x2 = x1, x2
        x1.dot(x2)
@@ -60,12 +70,11 @@ module DNN
      include LayerNode
 
      def forward_node(x)
-       @x = x
-       Xumo::NMath.exp(x)
+       @y = Xumo::NMath.exp(x)
      end
 
      def backward_node(dy)
-       dy * Xumo::NMath.exp(@x)
+       dy * @y
      end
    end
 
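A note on the Exp change above (an observation, not taken from the gem's docs): because d/dx exp(x) = exp(x), caching the forward output is enough for the backward pass and avoids recomputing the exponential.

    # sketch of the idea (Xumo is ruby-dnn's backend alias for Numo or Cumo)
    y  = Xumo::NMath.exp(x)   # forward_node caches this as @y
    dx = dy * y               # backward_node reuses it instead of calling exp again
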
@@ -122,20 +131,16 @@ module DNN
      end
 
      def forward_node(x)
-       if @axis
-         @dim = x.shape[@axis]
-         x.sum(axis: @axis, keepdims: true)
-       else
-         x.sum
-       end
+       @x_shape = x.shape
+       @dim = x.shape[@axis]
+       x.sum(axis: @axis, keepdims: true)
      end
 
      def backward_node(dy)
-       dx = dy.clone
-       if @axis
-         (@dim - 1).times do
-           dx = dx.concatenate(dy, axis: @axis)
-         end
+       return dy if @x_shape == dy.shape
+       dx = dy
+       (@dim - 1).times do
+         dx = dx.concatenate(dy, axis: @axis)
        end
        dx
      end
@@ -150,16 +155,16 @@ module DNN
      end
 
      def forward_node(x)
-       @dim = @axis ? x.shape[@axis] : x.size
+       @x_shape = x.shape
+       @dim = x.shape[@axis]
        x.mean(axis: @axis, keepdims: true)
      end
 
      def backward_node(dy)
+       return dy / @dim if @x_shape == dy.shape
        dx = dy
-       if @axis
-         (@dim - 1).times do
-           dx = dx.concatenate(dy, axis: @axis)
-         end
+       (@dim - 1).times do
+         dx = dx.concatenate(dy, axis: @axis)
        end
        dx / @dim
      end
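
The rewritten Sum and Mean backward passes above tile the incoming gradient back to the input shape by repeatedly concatenating it along the reduced axis; Mean additionally divides by the number of elements that were averaged (@dim), and the early return covers the case where the reduction did not change the shape. A minimal sketch of the tiling idea with plain Numo arrays (the shapes here are made up for illustration):

    require "numo/narray"

    dy  = Numo::SFloat.ones(3, 1)   # upstream gradient of a sum over axis 1 of a (3, 4) input
    dim = 4
    dx  = dy
    (dim - 1).times { dx = dx.concatenate(dy, axis: 1) }
    p dx.shape                      # => [3, 4]; for Mean, this would then be divided by dim
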
@@ -26,8 +26,6 @@ module DNN
    end
 
    class MergeLayer < Layers::Layer
-     include MergeLayerNode
-
      def self.call(x1, x2, *args)
        new(*args).call(x1, x2)
      end
@@ -45,6 +43,8 @@ module DNN
    end
 
    class Concatenate < MergeLayer
+     include MergeLayerNode
+
      attr_reader :axis
 
      def initialize(axis: 1)
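
With this change MergeLayerNode is no longer mixed into the MergeLayer base class; instead each built-in merge layer (Add, Sub, Mul, Div, Dot, Concatenate) includes it explicitly. The apparent effect is that a user-defined MergeLayer subclass can implement forward in terms of other differentiable layers, as the Sampling layer in the new VAE example does, rather than having to supply forward_node and backward_node. A minimal sketch of that pattern (the class name is made up for illustration):

    class ScaledTanhSum < DNN::Layers::MergeLayer
      # forward is composed from differentiable layers, so no manual backward is required
      def forward(x1, x2)
        DNN::Layers::Tanh.(x1 + x2)
      end
    end
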
@@ -75,8 +75,8 @@ module DNN
        0.5 * ((y - t)**2).mean(0).sum
      end
 
-     def backward_node(dy)
-       @y - @t
+     def backward_node(d)
+       (@y - @t) / @y.shape[0]
      end
    end
 
@@ -90,10 +90,10 @@ module DNN
      end
 
      def backward_node(d)
-       dy = @y - @t
+       dy = (@y - @t)
        dy[dy >= 0] = 1
        dy[dy < 0] = -1
-       dy
+       dy / @y.shape[0]
      end
    end
 
@@ -109,7 +109,7 @@ module DNN
      def backward_node(d)
        a = Xumo::SFloat.ones(*@a.shape)
        a[@a <= 0] = 0
-       a * -@t
+       (a * -@t) / a.shape[0]
      end
    end
 
@@ -124,12 +124,12 @@ module DNN
      end
 
      def backward_node(d)
-       dy = @y - @t
+       dy = (@y - @t)
        if @loss_value > 1
          dy[dy >= 0] = 1
          dy[dy < 0] = -1
        end
-       dy
+       dy / @y.shape[0]
      end
 
      private
@@ -168,7 +168,7 @@ module DNN
      end
 
      def backward_node(d)
-       @x - @t
+       (@x - @t) / @x.shape[0]
      end
 
      def to_hash
@@ -205,7 +205,7 @@ module DNN
      end
 
      def backward_node(d)
-       @x - @t
+       (@x - @t) / @x.shape[0]
      end
 
      def to_hash
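
All of the loss backward_node methods above now divide the gradient by the batch size. For reference (standard calculus, not quoted from the gem's docs), with a mean squared error averaged over a batch of N samples,

    L = (1/N) * sum_i 0.5 * ||y_i - t_i||^2   so   dL/dy_i = (y_i - t_i) / N

which is what (@y - @t) / @y.shape[0] computes; the matching change to train_on_batch in the next hunk seeds the backward pass with ones instead of zeros so these averaged gradients propagate through the graph unscaled.
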
@@ -262,7 +262,7 @@ module DNN
        DNN.learning_phase = true
        out = call(Tensor.convert(x))
        loss = @loss_func.loss(out, Tensor.convert(y), layers)
-       loss.link.backward(Xumo::SFloat.zeros(y[0...1, false].shape))
+       loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
        @optimizer.update(get_all_trainable_params)
        @last_log[:train_loss] = loss.data
        call_callbacks(:after_train_on_batch)
@@ -353,7 +353,12 @@ module DNN
      # @param [Numo::SFloat] x Input data. However, x is single data.
      def predict1(x, use_loss_activation: true)
        check_xy_type(x)
-       predict(x.reshape(1, *x.shape), use_loss_activation: use_loss_activation)[0, false]
+       input = if x.is_a?(Array)
+                 x.map { |v| v.reshape(1, *v.shape) }
+               else
+                 x.reshape(1, *x.shape)
+               end
+       predict(input, use_loss_activation: use_loss_activation)[0, false]
      end
 
      # Add callback function.
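
predict1 now also accepts an Array of inputs, reshaping each element to a batch of one, so it can be used with multi-input models. A small usage sketch (the model and shapes here are hypothetical):

    # single-input model: pass one sample
    out = model.predict1(Numo::SFloat.new(784).rand)

    # model whose forward takes two inputs: pass one sample per input
    out = model.predict1([Numo::SFloat.new(16).rand, Numo::SFloat.new(16).rand])
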
@@ -458,7 +463,7 @@ module DNN
        if !x.is_a?(Xumo::SFloat) && !x.is_a?(Array)
          raise TypeError, "x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
        end
-       if y && !y.is_a?(Xumo::SFloat) && !x.is_a?(Array)
+       if y && !y.is_a?(Xumo::SFloat) && !y.is_a?(Array)
          raise TypeError, "y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
        end
      end
@@ -1,3 +1,3 @@
  module DNN
-   VERSION = "0.16.1"
+   VERSION = "0.16.2"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: ruby-dnn
  version: !ruby/object:Gem::Version
-   version: 0.16.1
+   version: 0.16.2
  platform: ruby
  authors:
  - unagiootoro
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2020-01-06 00:00:00.000000000 Z
+ date: 2020-01-11 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -114,6 +114,7 @@ files:
  - examples/pix2pix/dcgan.rb
  - examples/pix2pix/imgen.rb
  - examples/pix2pix/train.rb
+ - examples/vae.rb
  - examples/xor_example.rb
  - ext/rb_stb_image/extconf.rb
  - ext/rb_stb_image/rb_stb_image.c