ruby-dnn 0.16.1 → 0.16.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/examples/vae.rb +118 -0
- data/lib/dnn/core/layers/math_layers.rb +24 -19
- data/lib/dnn/core/layers/merge_layers.rb +2 -2
- data/lib/dnn/core/losses.rb +9 -9
- data/lib/dnn/core/models.rb +8 -3
- data/lib/dnn/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
|
4
|
+
data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
|
7
|
+
data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
|
data/examples/vae.rb
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/datasets/mnist"
|
3
|
+
require "dnn/image"
|
4
|
+
require "numo/linalg/autoloader"
|
5
|
+
|
6
|
+
include DNN::Models
|
7
|
+
include DNN::Layers
|
8
|
+
include DNN::Optimizers
|
9
|
+
include DNN::Losses
|
10
|
+
|
11
|
+
x_train, y_train = DNN::MNIST.load_train
|
12
|
+
x_test, y_test = DNN::MNIST.load_test
|
13
|
+
|
14
|
+
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
|
15
|
+
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
|
16
|
+
|
17
|
+
x_train /= 255
|
18
|
+
x_test /= 255
|
19
|
+
|
20
|
+
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
21
|
+
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
22
|
+
|
23
|
+
$z_dim = 2
|
24
|
+
$z_mean = nil
|
25
|
+
$z_sigma = nil
|
26
|
+
|
27
|
+
class Sampling < MergeLayer
|
28
|
+
def forward(z_mean, z_sigma)
|
29
|
+
epsilon = DNN::Tensor.new(Numo::SFloat.new($z_dim).rand_norm(0, 1))
|
30
|
+
Tanh.(z_mean + z_sigma * epsilon)
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
34
|
+
class Encoder < Model
|
35
|
+
def initialize
|
36
|
+
super
|
37
|
+
@l1 = Dense.new(196)
|
38
|
+
@l2 = Dense.new(49)
|
39
|
+
@l3_1 = Dense.new($z_dim)
|
40
|
+
@l3_2 = Dense.new($z_dim)
|
41
|
+
@bn1 = BatchNormalization.new
|
42
|
+
@bn2 = BatchNormalization.new
|
43
|
+
end
|
44
|
+
|
45
|
+
def forward(x)
|
46
|
+
x = InputLayer.new(784).(x)
|
47
|
+
x = @l1.(x)
|
48
|
+
x = @bn1.(x)
|
49
|
+
x = ReLU.(x)
|
50
|
+
x = @l2.(x)
|
51
|
+
x = @bn2.(x)
|
52
|
+
x = ReLU.(x)
|
53
|
+
z_mean = @l3_1.(x)
|
54
|
+
z_sigma = @l3_2.(x)
|
55
|
+
[z_mean, z_sigma]
|
56
|
+
end
|
57
|
+
end
|
58
|
+
|
59
|
+
class Decoder < Model
|
60
|
+
def initialize
|
61
|
+
super
|
62
|
+
@l3 = Dense.new(196)
|
63
|
+
@l4 = Dense.new(784)
|
64
|
+
@bn1 = BatchNormalization.new
|
65
|
+
end
|
66
|
+
|
67
|
+
def forward(z)
|
68
|
+
x = @l3.(z)
|
69
|
+
x = @bn1.(x)
|
70
|
+
x = ReLU.(x)
|
71
|
+
x = @l4.(x)
|
72
|
+
x
|
73
|
+
end
|
74
|
+
end
|
75
|
+
|
76
|
+
class VAE < Model
|
77
|
+
attr_accessor :enc
|
78
|
+
attr_accessor :dec
|
79
|
+
|
80
|
+
def initialize(enc = nil, dec = nil)
|
81
|
+
super()
|
82
|
+
@enc = enc || Encoder.new
|
83
|
+
@dec = dec || Decoder.new
|
84
|
+
end
|
85
|
+
|
86
|
+
def forward(x)
|
87
|
+
z_mean, z_sigma = @enc.(x)
|
88
|
+
$z_mean, $z_sigma = z_mean, z_sigma
|
89
|
+
z = Sampling.(z_mean, z_sigma)
|
90
|
+
x = @dec.(z)
|
91
|
+
x
|
92
|
+
end
|
93
|
+
end
|
94
|
+
|
95
|
+
class VAELoss < Loss
|
96
|
+
def forward(y, t)
|
97
|
+
kl = -0.5 * Mean.(Sum.(1 + Log.($z_sigma**2) - $z_mean**2 - $z_sigma**2, axis: 1), axis: 0)
|
98
|
+
SigmoidCrossEntropy.(y, t) + kl
|
99
|
+
end
|
100
|
+
end
|
101
|
+
|
102
|
+
model = VAE.new
|
103
|
+
dec = model.dec
|
104
|
+
model.setup(Adam.new, VAELoss.new)
|
105
|
+
|
106
|
+
model.train(x_train, x_train, 10, batch_size: 100)
|
107
|
+
|
108
|
+
images = []
|
109
|
+
10.times do |i|
|
110
|
+
10.times do |j|
|
111
|
+
z1 = (i / 4.5) - 1
|
112
|
+
z2 = (j / 4.5) - 1
|
113
|
+
z = Numo::SFloat[z1, z2]
|
114
|
+
out = DNN::Utils.sigmoid(dec.predict1(z))
|
115
|
+
img = Numo::UInt8.cast(out * 255).reshape(28, 28, 1)
|
116
|
+
DNN::Image.write("img/img_#{i}_#{j}.png", img)
|
117
|
+
end
|
118
|
+
end
|
@@ -2,6 +2,8 @@ module DNN
|
|
2
2
|
module Layers
|
3
3
|
|
4
4
|
class Add < MergeLayer
|
5
|
+
include MergeLayerNode
|
6
|
+
|
5
7
|
def forward_node(x1, x2)
|
6
8
|
x1 + x2
|
7
9
|
end
|
@@ -12,6 +14,8 @@ module DNN
|
|
12
14
|
end
|
13
15
|
|
14
16
|
class Sub < MergeLayer
|
17
|
+
include MergeLayerNode
|
18
|
+
|
15
19
|
def forward_node(x1, x2)
|
16
20
|
x1 - x2
|
17
21
|
end
|
@@ -22,6 +26,8 @@ module DNN
|
|
22
26
|
end
|
23
27
|
|
24
28
|
class Mul < MergeLayer
|
29
|
+
include MergeLayerNode
|
30
|
+
|
25
31
|
def forward_node(x1, x2)
|
26
32
|
@x1, @x2 = x1, x2
|
27
33
|
x1 * x2
|
@@ -33,6 +39,8 @@ module DNN
|
|
33
39
|
end
|
34
40
|
|
35
41
|
class Div < MergeLayer
|
42
|
+
include MergeLayerNode
|
43
|
+
|
36
44
|
def forward_node(x1, x2)
|
37
45
|
@x1, @x2 = x1, x2
|
38
46
|
x1 / x2
|
@@ -46,6 +54,8 @@ module DNN
|
|
46
54
|
end
|
47
55
|
|
48
56
|
class Dot < MergeLayer
|
57
|
+
include MergeLayerNode
|
58
|
+
|
49
59
|
def forward_node(x1, x2)
|
50
60
|
@x1, @x2 = x1, x2
|
51
61
|
x1.dot(x2)
|
@@ -60,12 +70,11 @@ module DNN
|
|
60
70
|
include LayerNode
|
61
71
|
|
62
72
|
def forward_node(x)
|
63
|
-
@
|
64
|
-
Xumo::NMath.exp(x)
|
73
|
+
@y = Xumo::NMath.exp(x)
|
65
74
|
end
|
66
75
|
|
67
76
|
def backward_node(dy)
|
68
|
-
dy *
|
77
|
+
dy * @y
|
69
78
|
end
|
70
79
|
end
|
71
80
|
|
@@ -122,20 +131,16 @@ module DNN
|
|
122
131
|
end
|
123
132
|
|
124
133
|
def forward_node(x)
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
else
|
129
|
-
x.sum
|
130
|
-
end
|
134
|
+
@x_shape = x.shape
|
135
|
+
@dim = x.shape[@axis]
|
136
|
+
x.sum(axis: @axis, keepdims: true)
|
131
137
|
end
|
132
138
|
|
133
139
|
def backward_node(dy)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
end
|
140
|
+
return dy if @x_shape == dy.shape
|
141
|
+
dx = dy
|
142
|
+
(@dim - 1).times do
|
143
|
+
dx = dx.concatenate(dy, axis: @axis)
|
139
144
|
end
|
140
145
|
dx
|
141
146
|
end
|
@@ -150,16 +155,16 @@ module DNN
|
|
150
155
|
end
|
151
156
|
|
152
157
|
def forward_node(x)
|
153
|
-
@
|
158
|
+
@x_shape = x.shape
|
159
|
+
@dim = x.shape[@axis]
|
154
160
|
x.mean(axis: @axis, keepdims: true)
|
155
161
|
end
|
156
162
|
|
157
163
|
def backward_node(dy)
|
164
|
+
return dy / @dim if @x_shape == dy.shape
|
158
165
|
dx = dy
|
159
|
-
|
160
|
-
(
|
161
|
-
dx = dx.concatenate(dy, axis: @axis)
|
162
|
-
end
|
166
|
+
(@dim - 1).times do
|
167
|
+
dx = dx.concatenate(dy, axis: @axis)
|
163
168
|
end
|
164
169
|
dx / @dim
|
165
170
|
end
|
@@ -26,8 +26,6 @@ module DNN
|
|
26
26
|
end
|
27
27
|
|
28
28
|
class MergeLayer < Layers::Layer
|
29
|
-
include MergeLayerNode
|
30
|
-
|
31
29
|
def self.call(x1, x2, *args)
|
32
30
|
new(*args).call(x1, x2)
|
33
31
|
end
|
@@ -45,6 +43,8 @@ module DNN
|
|
45
43
|
end
|
46
44
|
|
47
45
|
class Concatenate < MergeLayer
|
46
|
+
include MergeLayerNode
|
47
|
+
|
48
48
|
attr_reader :axis
|
49
49
|
|
50
50
|
def initialize(axis: 1)
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -75,8 +75,8 @@ module DNN
|
|
75
75
|
0.5 * ((y - t)**2).mean(0).sum
|
76
76
|
end
|
77
77
|
|
78
|
-
def backward_node(
|
79
|
-
@y - @t
|
78
|
+
def backward_node(d)
|
79
|
+
(@y - @t) / @y.shape[0]
|
80
80
|
end
|
81
81
|
end
|
82
82
|
|
@@ -90,10 +90,10 @@ module DNN
|
|
90
90
|
end
|
91
91
|
|
92
92
|
def backward_node(d)
|
93
|
-
dy = @y - @t
|
93
|
+
dy = (@y - @t)
|
94
94
|
dy[dy >= 0] = 1
|
95
95
|
dy[dy < 0] = -1
|
96
|
-
dy
|
96
|
+
dy / @y.shape[0]
|
97
97
|
end
|
98
98
|
end
|
99
99
|
|
@@ -109,7 +109,7 @@ module DNN
|
|
109
109
|
def backward_node(d)
|
110
110
|
a = Xumo::SFloat.ones(*@a.shape)
|
111
111
|
a[@a <= 0] = 0
|
112
|
-
a * -@t
|
112
|
+
(a * -@t) / a.shape[0]
|
113
113
|
end
|
114
114
|
end
|
115
115
|
|
@@ -124,12 +124,12 @@ module DNN
|
|
124
124
|
end
|
125
125
|
|
126
126
|
def backward_node(d)
|
127
|
-
dy = @y - @t
|
127
|
+
dy = (@y - @t)
|
128
128
|
if @loss_value > 1
|
129
129
|
dy[dy >= 0] = 1
|
130
130
|
dy[dy < 0] = -1
|
131
131
|
end
|
132
|
-
dy
|
132
|
+
dy / @y.shape[0]
|
133
133
|
end
|
134
134
|
|
135
135
|
private
|
@@ -168,7 +168,7 @@ module DNN
|
|
168
168
|
end
|
169
169
|
|
170
170
|
def backward_node(d)
|
171
|
-
@x - @t
|
171
|
+
(@x - @t) / @x.shape[0]
|
172
172
|
end
|
173
173
|
|
174
174
|
def to_hash
|
@@ -205,7 +205,7 @@ module DNN
|
|
205
205
|
end
|
206
206
|
|
207
207
|
def backward_node(d)
|
208
|
-
@x - @t
|
208
|
+
(@x - @t) / @x.shape[0]
|
209
209
|
end
|
210
210
|
|
211
211
|
def to_hash
|
data/lib/dnn/core/models.rb
CHANGED
@@ -262,7 +262,7 @@ module DNN
|
|
262
262
|
DNN.learning_phase = true
|
263
263
|
out = call(Tensor.convert(x))
|
264
264
|
loss = @loss_func.loss(out, Tensor.convert(y), layers)
|
265
|
-
loss.link.backward(Xumo::SFloat.
|
265
|
+
loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
|
266
266
|
@optimizer.update(get_all_trainable_params)
|
267
267
|
@last_log[:train_loss] = loss.data
|
268
268
|
call_callbacks(:after_train_on_batch)
|
@@ -353,7 +353,12 @@ module DNN
|
|
353
353
|
# @param [Numo::SFloat] x Input data. However, x is single data.
|
354
354
|
def predict1(x, use_loss_activation: true)
|
355
355
|
check_xy_type(x)
|
356
|
-
|
356
|
+
input = if x.is_a?(Array)
|
357
|
+
x.map { |v| v.reshape(1, *v.shape) }
|
358
|
+
else
|
359
|
+
x.reshape(1, *x.shape)
|
360
|
+
end
|
361
|
+
predict(input, use_loss_activation: use_loss_activation)[0, false]
|
357
362
|
end
|
358
363
|
|
359
364
|
# Add callback function.
|
@@ -458,7 +463,7 @@ module DNN
|
|
458
463
|
if !x.is_a?(Xumo::SFloat) && !x.is_a?(Array)
|
459
464
|
raise TypeError, "x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
|
460
465
|
end
|
461
|
-
if y && !y.is_a?(Xumo::SFloat) && !
|
466
|
+
if y && !y.is_a?(Xumo::SFloat) && !y.is_a?(Array)
|
462
467
|
raise TypeError, "y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
|
463
468
|
end
|
464
469
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.16.
|
4
|
+
version: 0.16.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-01-
|
11
|
+
date: 2020-01-11 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -114,6 +114,7 @@ files:
|
|
114
114
|
- examples/pix2pix/dcgan.rb
|
115
115
|
- examples/pix2pix/imgen.rb
|
116
116
|
- examples/pix2pix/train.rb
|
117
|
+
- examples/vae.rb
|
117
118
|
- examples/xor_example.rb
|
118
119
|
- ext/rb_stb_image/extconf.rb
|
119
120
|
- ext/rb_stb_image/rb_stb_image.c
|