ruby-dnn 0.16.1 → 0.16.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/vae.rb +118 -0
- data/lib/dnn/core/layers/math_layers.rb +24 -19
- data/lib/dnn/core/layers/merge_layers.rb +2 -2
- data/lib/dnn/core/losses.rb +9 -9
- data/lib/dnn/core/models.rb +8 -3
- data/lib/dnn/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
|
|
4
|
+
data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
|
|
7
|
+
data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
|
data/examples/vae.rb
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
require "dnn"
|
|
2
|
+
require "dnn/datasets/mnist"
|
|
3
|
+
require "dnn/image"
|
|
4
|
+
require "numo/linalg/autoloader"
|
|
5
|
+
|
|
6
|
+
include DNN::Models
|
|
7
|
+
include DNN::Layers
|
|
8
|
+
include DNN::Optimizers
|
|
9
|
+
include DNN::Losses
|
|
10
|
+
|
|
11
|
+
x_train, y_train = DNN::MNIST.load_train
|
|
12
|
+
x_test, y_test = DNN::MNIST.load_test
|
|
13
|
+
|
|
14
|
+
x_train = Numo::SFloat.cast(x_train).reshape(x_train.shape[0], 784)
|
|
15
|
+
x_test = Numo::SFloat.cast(x_test).reshape(x_test.shape[0], 784)
|
|
16
|
+
|
|
17
|
+
x_train /= 255
|
|
18
|
+
x_test /= 255
|
|
19
|
+
|
|
20
|
+
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
|
|
21
|
+
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
|
22
|
+
|
|
23
|
+
$z_dim = 2
|
|
24
|
+
$z_mean = nil
|
|
25
|
+
$z_sigma = nil
|
|
26
|
+
|
|
27
|
+
class Sampling < MergeLayer
|
|
28
|
+
def forward(z_mean, z_sigma)
|
|
29
|
+
epsilon = DNN::Tensor.new(Numo::SFloat.new($z_dim).rand_norm(0, 1))
|
|
30
|
+
Tanh.(z_mean + z_sigma * epsilon)
|
|
31
|
+
end
|
|
32
|
+
end
|
|
33
|
+
|
|
34
|
+
class Encoder < Model
|
|
35
|
+
def initialize
|
|
36
|
+
super
|
|
37
|
+
@l1 = Dense.new(196)
|
|
38
|
+
@l2 = Dense.new(49)
|
|
39
|
+
@l3_1 = Dense.new($z_dim)
|
|
40
|
+
@l3_2 = Dense.new($z_dim)
|
|
41
|
+
@bn1 = BatchNormalization.new
|
|
42
|
+
@bn2 = BatchNormalization.new
|
|
43
|
+
end
|
|
44
|
+
|
|
45
|
+
def forward(x)
|
|
46
|
+
x = InputLayer.new(784).(x)
|
|
47
|
+
x = @l1.(x)
|
|
48
|
+
x = @bn1.(x)
|
|
49
|
+
x = ReLU.(x)
|
|
50
|
+
x = @l2.(x)
|
|
51
|
+
x = @bn2.(x)
|
|
52
|
+
x = ReLU.(x)
|
|
53
|
+
z_mean = @l3_1.(x)
|
|
54
|
+
z_sigma = @l3_2.(x)
|
|
55
|
+
[z_mean, z_sigma]
|
|
56
|
+
end
|
|
57
|
+
end
|
|
58
|
+
|
|
59
|
+
class Decoder < Model
|
|
60
|
+
def initialize
|
|
61
|
+
super
|
|
62
|
+
@l3 = Dense.new(196)
|
|
63
|
+
@l4 = Dense.new(784)
|
|
64
|
+
@bn1 = BatchNormalization.new
|
|
65
|
+
end
|
|
66
|
+
|
|
67
|
+
def forward(z)
|
|
68
|
+
x = @l3.(z)
|
|
69
|
+
x = @bn1.(x)
|
|
70
|
+
x = ReLU.(x)
|
|
71
|
+
x = @l4.(x)
|
|
72
|
+
x
|
|
73
|
+
end
|
|
74
|
+
end
|
|
75
|
+
|
|
76
|
+
class VAE < Model
|
|
77
|
+
attr_accessor :enc
|
|
78
|
+
attr_accessor :dec
|
|
79
|
+
|
|
80
|
+
def initialize(enc = nil, dec = nil)
|
|
81
|
+
super()
|
|
82
|
+
@enc = enc || Encoder.new
|
|
83
|
+
@dec = dec || Decoder.new
|
|
84
|
+
end
|
|
85
|
+
|
|
86
|
+
def forward(x)
|
|
87
|
+
z_mean, z_sigma = @enc.(x)
|
|
88
|
+
$z_mean, $z_sigma = z_mean, z_sigma
|
|
89
|
+
z = Sampling.(z_mean, z_sigma)
|
|
90
|
+
x = @dec.(z)
|
|
91
|
+
x
|
|
92
|
+
end
|
|
93
|
+
end
|
|
94
|
+
|
|
95
|
+
class VAELoss < Loss
|
|
96
|
+
def forward(y, t)
|
|
97
|
+
kl = -0.5 * Mean.(Sum.(1 + Log.($z_sigma**2) - $z_mean**2 - $z_sigma**2, axis: 1), axis: 0)
|
|
98
|
+
SigmoidCrossEntropy.(y, t) + kl
|
|
99
|
+
end
|
|
100
|
+
end
|
|
101
|
+
|
|
102
|
+
model = VAE.new
|
|
103
|
+
dec = model.dec
|
|
104
|
+
model.setup(Adam.new, VAELoss.new)
|
|
105
|
+
|
|
106
|
+
model.train(x_train, x_train, 10, batch_size: 100)
|
|
107
|
+
|
|
108
|
+
images = []
|
|
109
|
+
10.times do |i|
|
|
110
|
+
10.times do |j|
|
|
111
|
+
z1 = (i / 4.5) - 1
|
|
112
|
+
z2 = (j / 4.5) - 1
|
|
113
|
+
z = Numo::SFloat[z1, z2]
|
|
114
|
+
out = DNN::Utils.sigmoid(dec.predict1(z))
|
|
115
|
+
img = Numo::UInt8.cast(out * 255).reshape(28, 28, 1)
|
|
116
|
+
DNN::Image.write("img/img_#{i}_#{j}.png", img)
|
|
117
|
+
end
|
|
118
|
+
end
|
|
@@ -2,6 +2,8 @@ module DNN
|
|
|
2
2
|
module Layers
|
|
3
3
|
|
|
4
4
|
class Add < MergeLayer
|
|
5
|
+
include MergeLayerNode
|
|
6
|
+
|
|
5
7
|
def forward_node(x1, x2)
|
|
6
8
|
x1 + x2
|
|
7
9
|
end
|
|
@@ -12,6 +14,8 @@ module DNN
|
|
|
12
14
|
end
|
|
13
15
|
|
|
14
16
|
class Sub < MergeLayer
|
|
17
|
+
include MergeLayerNode
|
|
18
|
+
|
|
15
19
|
def forward_node(x1, x2)
|
|
16
20
|
x1 - x2
|
|
17
21
|
end
|
|
@@ -22,6 +26,8 @@ module DNN
|
|
|
22
26
|
end
|
|
23
27
|
|
|
24
28
|
class Mul < MergeLayer
|
|
29
|
+
include MergeLayerNode
|
|
30
|
+
|
|
25
31
|
def forward_node(x1, x2)
|
|
26
32
|
@x1, @x2 = x1, x2
|
|
27
33
|
x1 * x2
|
|
@@ -33,6 +39,8 @@ module DNN
|
|
|
33
39
|
end
|
|
34
40
|
|
|
35
41
|
class Div < MergeLayer
|
|
42
|
+
include MergeLayerNode
|
|
43
|
+
|
|
36
44
|
def forward_node(x1, x2)
|
|
37
45
|
@x1, @x2 = x1, x2
|
|
38
46
|
x1 / x2
|
|
@@ -46,6 +54,8 @@ module DNN
|
|
|
46
54
|
end
|
|
47
55
|
|
|
48
56
|
class Dot < MergeLayer
|
|
57
|
+
include MergeLayerNode
|
|
58
|
+
|
|
49
59
|
def forward_node(x1, x2)
|
|
50
60
|
@x1, @x2 = x1, x2
|
|
51
61
|
x1.dot(x2)
|
|
@@ -60,12 +70,11 @@ module DNN
|
|
|
60
70
|
include LayerNode
|
|
61
71
|
|
|
62
72
|
def forward_node(x)
|
|
63
|
-
@
|
|
64
|
-
Xumo::NMath.exp(x)
|
|
73
|
+
@y = Xumo::NMath.exp(x)
|
|
65
74
|
end
|
|
66
75
|
|
|
67
76
|
def backward_node(dy)
|
|
68
|
-
dy *
|
|
77
|
+
dy * @y
|
|
69
78
|
end
|
|
70
79
|
end
|
|
71
80
|
|
|
@@ -122,20 +131,16 @@ module DNN
|
|
|
122
131
|
end
|
|
123
132
|
|
|
124
133
|
def forward_node(x)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
else
|
|
129
|
-
x.sum
|
|
130
|
-
end
|
|
134
|
+
@x_shape = x.shape
|
|
135
|
+
@dim = x.shape[@axis]
|
|
136
|
+
x.sum(axis: @axis, keepdims: true)
|
|
131
137
|
end
|
|
132
138
|
|
|
133
139
|
def backward_node(dy)
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
end
|
|
140
|
+
return dy if @x_shape == dy.shape
|
|
141
|
+
dx = dy
|
|
142
|
+
(@dim - 1).times do
|
|
143
|
+
dx = dx.concatenate(dy, axis: @axis)
|
|
139
144
|
end
|
|
140
145
|
dx
|
|
141
146
|
end
|
|
@@ -150,16 +155,16 @@ module DNN
|
|
|
150
155
|
end
|
|
151
156
|
|
|
152
157
|
def forward_node(x)
|
|
153
|
-
@
|
|
158
|
+
@x_shape = x.shape
|
|
159
|
+
@dim = x.shape[@axis]
|
|
154
160
|
x.mean(axis: @axis, keepdims: true)
|
|
155
161
|
end
|
|
156
162
|
|
|
157
163
|
def backward_node(dy)
|
|
164
|
+
return dy / @dim if @x_shape == dy.shape
|
|
158
165
|
dx = dy
|
|
159
|
-
|
|
160
|
-
(
|
|
161
|
-
dx = dx.concatenate(dy, axis: @axis)
|
|
162
|
-
end
|
|
166
|
+
(@dim - 1).times do
|
|
167
|
+
dx = dx.concatenate(dy, axis: @axis)
|
|
163
168
|
end
|
|
164
169
|
dx / @dim
|
|
165
170
|
end
|
|
@@ -26,8 +26,6 @@ module DNN
|
|
|
26
26
|
end
|
|
27
27
|
|
|
28
28
|
class MergeLayer < Layers::Layer
|
|
29
|
-
include MergeLayerNode
|
|
30
|
-
|
|
31
29
|
def self.call(x1, x2, *args)
|
|
32
30
|
new(*args).call(x1, x2)
|
|
33
31
|
end
|
|
@@ -45,6 +43,8 @@ module DNN
|
|
|
45
43
|
end
|
|
46
44
|
|
|
47
45
|
class Concatenate < MergeLayer
|
|
46
|
+
include MergeLayerNode
|
|
47
|
+
|
|
48
48
|
attr_reader :axis
|
|
49
49
|
|
|
50
50
|
def initialize(axis: 1)
|
data/lib/dnn/core/losses.rb
CHANGED
|
@@ -75,8 +75,8 @@ module DNN
|
|
|
75
75
|
0.5 * ((y - t)**2).mean(0).sum
|
|
76
76
|
end
|
|
77
77
|
|
|
78
|
-
def backward_node(
|
|
79
|
-
@y - @t
|
|
78
|
+
def backward_node(d)
|
|
79
|
+
(@y - @t) / @y.shape[0]
|
|
80
80
|
end
|
|
81
81
|
end
|
|
82
82
|
|
|
@@ -90,10 +90,10 @@ module DNN
|
|
|
90
90
|
end
|
|
91
91
|
|
|
92
92
|
def backward_node(d)
|
|
93
|
-
dy = @y - @t
|
|
93
|
+
dy = (@y - @t)
|
|
94
94
|
dy[dy >= 0] = 1
|
|
95
95
|
dy[dy < 0] = -1
|
|
96
|
-
dy
|
|
96
|
+
dy / @y.shape[0]
|
|
97
97
|
end
|
|
98
98
|
end
|
|
99
99
|
|
|
@@ -109,7 +109,7 @@ module DNN
|
|
|
109
109
|
def backward_node(d)
|
|
110
110
|
a = Xumo::SFloat.ones(*@a.shape)
|
|
111
111
|
a[@a <= 0] = 0
|
|
112
|
-
a * -@t
|
|
112
|
+
(a * -@t) / a.shape[0]
|
|
113
113
|
end
|
|
114
114
|
end
|
|
115
115
|
|
|
@@ -124,12 +124,12 @@ module DNN
|
|
|
124
124
|
end
|
|
125
125
|
|
|
126
126
|
def backward_node(d)
|
|
127
|
-
dy = @y - @t
|
|
127
|
+
dy = (@y - @t)
|
|
128
128
|
if @loss_value > 1
|
|
129
129
|
dy[dy >= 0] = 1
|
|
130
130
|
dy[dy < 0] = -1
|
|
131
131
|
end
|
|
132
|
-
dy
|
|
132
|
+
dy / @y.shape[0]
|
|
133
133
|
end
|
|
134
134
|
|
|
135
135
|
private
|
|
@@ -168,7 +168,7 @@ module DNN
|
|
|
168
168
|
end
|
|
169
169
|
|
|
170
170
|
def backward_node(d)
|
|
171
|
-
@x - @t
|
|
171
|
+
(@x - @t) / @x.shape[0]
|
|
172
172
|
end
|
|
173
173
|
|
|
174
174
|
def to_hash
|
|
@@ -205,7 +205,7 @@ module DNN
|
|
|
205
205
|
end
|
|
206
206
|
|
|
207
207
|
def backward_node(d)
|
|
208
|
-
@x - @t
|
|
208
|
+
(@x - @t) / @x.shape[0]
|
|
209
209
|
end
|
|
210
210
|
|
|
211
211
|
def to_hash
|
data/lib/dnn/core/models.rb
CHANGED
|
@@ -262,7 +262,7 @@ module DNN
|
|
|
262
262
|
DNN.learning_phase = true
|
|
263
263
|
out = call(Tensor.convert(x))
|
|
264
264
|
loss = @loss_func.loss(out, Tensor.convert(y), layers)
|
|
265
|
-
loss.link.backward(Xumo::SFloat.
|
|
265
|
+
loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
|
|
266
266
|
@optimizer.update(get_all_trainable_params)
|
|
267
267
|
@last_log[:train_loss] = loss.data
|
|
268
268
|
call_callbacks(:after_train_on_batch)
|
|
@@ -353,7 +353,12 @@ module DNN
|
|
|
353
353
|
# @param [Numo::SFloat] x Input data. However, x is single data.
|
|
354
354
|
def predict1(x, use_loss_activation: true)
|
|
355
355
|
check_xy_type(x)
|
|
356
|
-
|
|
356
|
+
input = if x.is_a?(Array)
|
|
357
|
+
x.map { |v| v.reshape(1, *v.shape) }
|
|
358
|
+
else
|
|
359
|
+
x.reshape(1, *x.shape)
|
|
360
|
+
end
|
|
361
|
+
predict(input, use_loss_activation: use_loss_activation)[0, false]
|
|
357
362
|
end
|
|
358
363
|
|
|
359
364
|
# Add callback function.
|
|
@@ -458,7 +463,7 @@ module DNN
|
|
|
458
463
|
if !x.is_a?(Xumo::SFloat) && !x.is_a?(Array)
|
|
459
464
|
raise TypeError, "x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
|
|
460
465
|
end
|
|
461
|
-
if y && !y.is_a?(Xumo::SFloat) && !
|
|
466
|
+
if y && !y.is_a?(Xumo::SFloat) && !y.is_a?(Array)
|
|
462
467
|
raise TypeError, "y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
|
|
463
468
|
end
|
|
464
469
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: ruby-dnn
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.16.
|
|
4
|
+
version: 0.16.2
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- unagiootoro
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2020-01-
|
|
11
|
+
date: 2020-01-11 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: numo-narray
|
|
@@ -114,6 +114,7 @@ files:
|
|
|
114
114
|
- examples/pix2pix/dcgan.rb
|
|
115
115
|
- examples/pix2pix/imgen.rb
|
|
116
116
|
- examples/pix2pix/train.rb
|
|
117
|
+
- examples/vae.rb
|
|
117
118
|
- examples/xor_example.rb
|
|
118
119
|
- ext/rb_stb_image/extconf.rb
|
|
119
120
|
- ext/rb_stb_image/rb_stb_image.c
|